diff --git a/src/meshcore/ble_cx.py b/src/meshcore/ble_cx.py index 0ce06d9..0a7380f 100644 --- a/src/meshcore/ble_cx.py +++ b/src/meshcore/ble_cx.py @@ -51,6 +51,14 @@ class BLEConnection: self.pin = pin self.rx_char = None self._disconnect_callback = None + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task async def connect(self): """ @@ -116,9 +124,12 @@ class BLEConnection: await self.client.pair() logger.info("BLE pairing successful") except Exception as e: - logger.warning(f"BLE pairing failed: {e}") - # Don't fail the connection if pairing fails, as the device - # might already be paired or not require pairing + logger.error(f"BLE pairing failed: {e}") + # A failed pairing leaves the transport in a half-usable + # state — re-raise so the caller gets a clean failure + # instead of a silently degraded connection. + await self.client.disconnect() + raise except BleakDeviceNotFoundError: return None @@ -154,8 +165,19 @@ class BLEConnection: self.client = self._user_provided_client self.device = self._user_provided_device + # Re-register disconnect callback on the reset client so subsequent + # disconnects after a reconnect cycle are still detected. + if self.client is not None and hasattr(self.client, 'set_disconnected_callback'): + try: + self.client.set_disconnected_callback(self.handle_disconnect) + except Exception: + # set_disconnected_callback may not be available on all bleak + # versions; the next connect() call will re-create the client + # with the callback anyway. + pass + if self._disconnect_callback: - asyncio.create_task(self._disconnect_callback("ble_disconnect")) + self._spawn_background(self._disconnect_callback("ble_disconnect")) def set_disconnect_callback(self, callback): """Set callback to handle disconnections.""" @@ -166,16 +188,24 @@ class BLEConnection: def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray): if self.reader is not None: - asyncio.create_task(self.reader.handle_rx(data)) + self._spawn_background(self.reader.handle_rx(data)) async def send(self, data): if not self.client: logger.error("Client is not connected") + if self._disconnect_callback: + await self._disconnect_callback("ble_transport_lost") return False if not self.rx_char: logger.error("RX characteristic not found") return False - await self.client.write_gatt_char(self.rx_char, bytes(data), response=True) + try: + await self.client.write_gatt_char(self.rx_char, bytes(data), response=True) + except Exception as exc: + logger.warning(f"BLE write failed: {exc}") + if self._disconnect_callback: + await self._disconnect_callback(f"ble_write_failed: {exc}") + return False async def disconnect(self): """Disconnect from the BLE device.""" diff --git a/src/meshcore/commands/base.py b/src/meshcore/commands/base.py index 9e0f00e..c0e4aae 100644 --- a/src/meshcore/commands/base.py +++ b/src/meshcore/commands/base.py @@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes class CommandHandlerBase: + """Base class for command handlers. + + .. note:: + The internal ``asyncio.Lock`` is created lazily on first access + so that it binds to the correct running event loop (required for + Python 3.9/3.10 compatibility). + """ + DEFAULT_TIMEOUT = 5.0 def __init__(self, default_timeout: Optional[float] = None): self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None self._reader: Optional[MessageReader] = None self.dispatcher: Optional[EventDispatcher] = None - self._mesh_request_lock = asyncio.Lock() + self.__mesh_request_lock: Optional[asyncio.Lock] = None self.default_timeout = ( default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT ) + @property + def _mesh_request_lock(self) -> asyncio.Lock: + """Lazy-init lock so it binds to the running loop, not import-time.""" + if self.__mesh_request_lock is None: + self.__mesh_request_lock = asyncio.Lock() + return self.__mesh_request_lock + def set_connection(self, connection: Any) -> None: async def sender(data: bytes) -> None: await connection.send(data) @@ -90,6 +105,14 @@ class CommandHandlerBase: expected_events: Optional[Union[EventType, List[EventType]]] = None, timeout: Optional[float] = None, ) -> Event: + """Wait for the first of *expected_events* to arrive. + + Returns the first matched ``Event``. When ``EventType.ERROR`` is + among the expected types, the caller **must** check + ``result.is_error()`` before accessing command-specific payload + keys — an ERROR payload is ``{"reason": "..."}`` and will + ``KeyError`` on any other key. + """ try: # Convert single event to list if needed if not isinstance(expected_events, list): @@ -129,9 +152,6 @@ class CommandHandlerBase: logger.debug(f"Command error: {e}") return Event(EventType.ERROR, {"error": str(e)}) - return Event(EventType.ERROR, {}) - - async def send( self, data: bytes, @@ -151,7 +171,14 @@ class CommandHandlerBase: timeout: Timeout in seconds, or None to use default_timeout Returns: - Event: The full event object that was received in response to the command + Event: The full event object that was received in response to + the command. + + Important: + When ``EventType.ERROR`` is included in *expected_events*, the + returned event may be an error response. Callers **must** + check ``result.is_error()`` before accessing command-specific + payload keys to avoid ``KeyError``. """ if not self.dispatcher: raise RuntimeError("Dispatcher not set, cannot send commands") @@ -170,7 +197,7 @@ class CommandHandlerBase: futures: List[asyncio.Future] = [] subscriptions = [] - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() for event_type in expected_events: future = loop.create_future() @@ -266,6 +293,7 @@ class CommandHandlerBase: contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path if contact is None: logger.error("No contact found") + return Event(EventType.ERROR, {"reason": "contact_not_found"}) zero_hop = False if contact["out_path_len"] == -1: diff --git a/src/meshcore/commands/contact.py b/src/meshcore/commands/contact.py index ac723ea..8a89101 100644 --- a/src/meshcore/commands/contact.py +++ b/src/meshcore/commands/contact.py @@ -185,6 +185,24 @@ class ContactCommands(CommandHandlerBase): data = b"\x3B" return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR]) + async def get_contact_by_key(self, pubkey: bytes) -> Event: + """N09: Retrieve a single contact by its public key (CMD 30). + + Args: + pubkey: 32-byte public key of the contact. + + Returns: + Event with the contact data (same format as CONTACT/NEXT_CONTACT), + or ERROR if not found. + """ + if not isinstance(pubkey, (bytes, bytearray)): + raise TypeError("pubkey must be bytes-like") + # Truncate or pad to 32 bytes + key_bytes = bytes(pubkey[:32]) + logger.debug(f"Getting contact by key: {key_bytes.hex()}") + data = b"\x1e" + key_bytes + return await self.send(data, [EventType.NEXT_CONTACT, EventType.ERROR]) + async def get_advert_path(self, key: DestinationType) -> Event: key_bytes = _validate_destination(key, prefix_length=32) logger.debug(f"getting advert path for: {key} {key_bytes.hex()}") diff --git a/src/meshcore/commands/device.py b/src/meshcore/commands/device.py index af986db..6175dc8 100644 --- a/src/meshcore/commands/device.py +++ b/src/meshcore/commands/device.py @@ -4,6 +4,7 @@ from hashlib import sha256 from typing import Optional from ..events import Event, EventType +from ..packets import CommandType from .base import CommandHandlerBase, DestinationType, _validate_destination logger = logging.getLogger("meshcore") @@ -13,7 +14,7 @@ class DeviceCommands(CommandHandlerBase): async def send_appstart(self) -> Event: logger.debug("Sending appstart command") b1 = bytearray(b"\x01\x03 mccli") - return await self.send(b1, [EventType.SELF_INFO]) + return await self.send(b1, [EventType.SELF_INFO, EventType.ERROR]) async def send_device_query(self) -> Event: logger.debug("Sending device query command") @@ -129,32 +130,50 @@ class DeviceCommands(CommandHandlerBase): return await self.send(data, [EventType.OK, EventType.ERROR]) async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_base"] = telemetry_mode_base return await self.set_other_params_from_infos(infos) async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_loc"] = telemetry_mode_loc return await self.set_other_params_from_infos(infos) async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["telemetry_mode_env"] = telemetry_mode_env return await self.set_other_params_from_infos(infos) async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["manual_add_contacts"] = manual_add_contacts return await self.set_other_params_from_infos(infos) async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["adv_loc_policy"] = advert_loc_policy return await self.set_other_params_from_infos(infos) async def set_multi_acks(self, multi_acks: int) -> Event: - infos = (await self.send_appstart()).payload + result = await self.send_appstart() + if result.is_error(): + return result + infos = result.payload infos["multi_acks"] = multi_acks return await self.set_other_params_from_infos(infos) @@ -273,20 +292,89 @@ class DeviceCommands(CommandHandlerBase): return await self.sign_finish(timeout=timeout, data_size=len(data)) + async def has_connection(self) -> Event: + """N09: Check if the device has an active connection (CMD 28). + + Returns: + Event with a 1-byte response indicating connection status, + or ERROR. + """ + logger.debug("Checking device connection status") + return await self.send(b"\x1c", [EventType.OK, EventType.ERROR]) + + async def get_tuning(self) -> Event: + """N03/N09: Request current tuning parameters (CMD_GET_TUNING_PARAMS = 43). + + Firmware responds with RESP_CODE_TUNING_PARAMS (23): 9 bytes containing + rx_delay (4 bytes LE) and airtime_factor (4 bytes LE). + + Returns: + Event of type TUNING_PARAMS with rx_delay and airtime_factor, + or ERROR. + """ + logger.debug("Getting tuning parameters") + return await self.send(b"\x2b", [EventType.TUNING_PARAMS, EventType.ERROR]) + + async def request_factory_reset(self) -> str: + """N09: Request a factory reset token (step 1 of 2). + + This method returns a confirmation token string. Pass it to + ``confirm_factory_reset(token)`` to actually execute the reset. + The two-step pattern is a Python-side safety measure; the firmware + itself has no token verification. + + Returns: + A confirmation token string to pass to confirm_factory_reset(). + """ + import secrets + token = secrets.token_hex(8) + logger.warning( + "Factory reset requested. Call confirm_factory_reset('%s') to proceed. " + "This will ERASE ALL DATA on the device.", token + ) + # Store the token on the instance for validation + self._factory_reset_token = token + return token + + async def confirm_factory_reset(self, token: str) -> Event: + """N09: Execute factory reset after token confirmation (step 2 of 2). + + Args: + token: The token returned by request_factory_reset(). + + Returns: + Event with OK or ERROR. + + Raises: + ValueError: If the token does not match. + """ + expected = getattr(self, "_factory_reset_token", None) + if expected is None or token != expected: + raise ValueError( + "Invalid or expired factory reset token. " + "Call request_factory_reset() first." + ) + self._factory_reset_token = None # Consume the token + logger.warning("Executing factory reset — all device data will be erased") + return await self.send(b"\x33", [EventType.OK, EventType.ERROR]) + async def get_stats_core(self) -> Event: logger.debug("Getting core statistics") - # CMD_GET_STATS (56) + STATS_TYPE_CORE (0) - return await self.send(b"\x38\x00", [EventType.STATS_CORE, EventType.ERROR]) + # R04: Use CommandType enum instead of literal bytes + cmd = bytes([CommandType.GET_STATS.value, 0x00]) # GET_STATS + STATS_TYPE_CORE + return await self.send(cmd, [EventType.STATS_CORE, EventType.ERROR]) async def get_stats_radio(self) -> Event: logger.debug("Getting radio statistics") - # CMD_GET_STATS (56) + STATS_TYPE_RADIO (1) - return await self.send(b"\x38\x01", [EventType.STATS_RADIO, EventType.ERROR]) + # R04: Use CommandType enum instead of literal bytes + cmd = bytes([CommandType.GET_STATS.value, 0x01]) # GET_STATS + STATS_TYPE_RADIO + return await self.send(cmd, [EventType.STATS_RADIO, EventType.ERROR]) async def get_stats_packets(self) -> Event: logger.debug("Getting packet statistics") - # CMD_GET_STATS (56) + STATS_TYPE_PACKETS (2) - return await self.send(b"\x38\x02", [EventType.STATS_PACKETS, EventType.ERROR]) + # R04: Use CommandType enum instead of literal bytes + cmd = bytes([CommandType.GET_STATS.value, 0x02]) # GET_STATS + STATS_TYPE_PACKETS + return await self.send(cmd, [EventType.STATS_PACKETS, EventType.ERROR]) async def get_allowed_repeat_freq(self) -> Event: logger.debug("Getting allowed repeat freqs") diff --git a/src/meshcore/commands/messaging.py b/src/meshcore/commands/messaging.py index b266ae0..604110b 100644 --- a/src/meshcore/commands/messaging.py +++ b/src/meshcore/commands/messaging.py @@ -144,8 +144,12 @@ class MessagingCommands(CommandHandlerBase): logger.info(f"Retry sending msg: {attempts + 1}") result = await self.send_msg(dst, msg, timestamp, attempt=attempts) - if result.type == EventType.ERROR: - logger.error(f"⚠️ Failed to send message: {result.payload}") + if result.is_error(): + logger.error(f"Failed to send message: {result.payload}") + attempts += 1 + if flood: + flood_attempts += 1 + continue exp_ack = result.payload["expected_ack"].hex() timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout @@ -255,7 +259,7 @@ class MessagingCommands(CommandHandlerBase): elif path_hash_len == 8 : flags = 3 else : - logger.error(f"Invalid path format: {e}") + logger.error(f"Invalid path format: unknown path_hash_len {path_hash_len}") return Event(EventType.ERROR, {"reason": "invalid_path_format"}) else: flags = 0 @@ -291,12 +295,34 @@ class MessagingCommands(CommandHandlerBase): cmd_data.append(flags) cmd_data.extend(path_bytes) + # N05: Firmware requires strict len > 10 (MyMesh.cpp:1620). + # When path is empty, cmd(1)+tag(4)+auth(4)+flags(1) = 10 bytes exactly, + # which is silently rejected. Pad with one zero byte to reach 11. + if len(cmd_data) <= 10: + cmd_data.append(0x00) + logger.debug( f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}" ) return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR]) + async def send_raw_data(self, payload: bytes) -> Event: + """N09: Send raw data via CMD_SEND_RAW_DATA (25). + + Sends an arbitrary payload through the mesh network. + + Args: + payload: Raw bytes to send. + + Returns: + Event with MSG_SENT or ERROR. + """ + if not isinstance(payload, (bytes, bytearray)): + raise TypeError("payload must be bytes-like") + data = b"\x19" + bytes(payload) + return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) + async def set_flood_scope(self, scope): if scope is None: logger.debug(f"Resetting scope") diff --git a/src/meshcore/connection_manager.py b/src/meshcore/connection_manager.py index c95ec37..bcd09ff 100644 --- a/src/meshcore/connection_manager.py +++ b/src/meshcore/connection_manager.py @@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type. import asyncio import logging -from typing import Optional, Any, Callable, Protocol +from typing import Optional, Any, Awaitable, Callable, Protocol from .events import Event, EventType logger = logging.getLogger("meshcore") class ConnectionProtocol(Protocol): - """Protocol defining the interface that connection classes must implement.""" + """Protocol defining the interface that connection classes must implement. + + Return contract for connect(): + - On success: return a truthy value (typically an address string) + that identifies the connection. This value is included in the + CONNECTED event payload as ``connection_info``. + - On failure: return ``None`` (soft failure — triggers a retry in + ``_attempt_reconnect``) **or** raise an exception (hard failure — + also triggers a retry, logged as an error). + """ async def connect(self) -> Optional[Any]: """Connect and return connection info, or None if failed.""" @@ -39,11 +48,13 @@ class ConnectionManager: event_dispatcher=None, auto_reconnect: bool = False, max_reconnect_attempts: int = 3, + reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None, ): self.connection = connection self.event_dispatcher = event_dispatcher self.auto_reconnect = auto_reconnect self.max_reconnect_attempts = max_reconnect_attempts + self._reconnect_callback = reconnect_callback self._reconnect_attempts = 0 self._is_connected = False @@ -109,45 +120,51 @@ class ConnectionManager: ) async def _attempt_reconnect(self): - """Attempt to reconnect with flat delay.""" - logger.debug( - f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})" - ) - self._reconnect_attempts += 1 + """Attempt to reconnect using an iterative loop. - # Flat 1 second delay for all attempts - await asyncio.sleep(1) + Runs as a single persistent task for the entire reconnect session. + Previous implementation used tail-recursion via create_task which + orphaned the running task reference — disconnect() could only cancel + the newest pointer, leaving earlier attempts in flight (F03). + """ + while self._reconnect_attempts < self.max_reconnect_attempts: + self._reconnect_attempts += 1 + logger.debug( + f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})" + ) + + # Flat 1 second delay for all attempts + await asyncio.sleep(1) + + try: + result = await self.connection.connect() + if result is not None: + self._is_connected = True + self._reconnect_attempts = 0 + + # Invoke reconnect callback (e.g. send_appstart) if provided + if self._reconnect_callback is not None: + try: + await self._reconnect_callback() + except Exception as cb_err: + logger.warning( + f"Reconnect callback failed: {cb_err}" + ) - try: - result = await self.connection.connect() - if result is not None: - self._is_connected = True - self._reconnect_attempts = 0 - await self._emit_event( - EventType.CONNECTED, - {"connection_info": result, "reconnected": True}, - ) - logger.debug("Reconnected successfully") - else: - # Reconnection failed, try again if we haven't exceeded max attempts - if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task( - self._attempt_reconnect() - ) - else: await self._emit_event( - EventType.DISCONNECTED, - {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + EventType.CONNECTED, + {"connection_info": result, "reconnected": True}, ) - except Exception as e: - logger.debug(f"Reconnection attempt failed: {e}") - if self._reconnect_attempts < self.max_reconnect_attempts: - self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) - else: - await self._emit_event( - EventType.DISCONNECTED, - {"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True}, - ) + logger.debug("Reconnected successfully") + return + except Exception as e: + logger.debug(f"Reconnection attempt failed: {e}") + + # All attempts exhausted + await self._emit_event( + EventType.DISCONNECTED, + {"reason": "reconnect_failed", "max_attempts_exceeded": True}, + ) async def _emit_event(self, event_type: EventType, payload: dict): """Emit connection events if dispatcher is available.""" diff --git a/src/meshcore/events.py b/src/meshcore/events.py index f8a7521..940ee18 100644 --- a/src/meshcore/events.py +++ b/src/meshcore/events.py @@ -49,6 +49,9 @@ class EventType(Enum): PATH_RESPONSE = "path_response" PRIVATE_KEY = "private_key" DISABLED = "disabled" + CONTACT_DELETED = "contact_deleted" + CONTACTS_FULL = "contacts_full" + TUNING_PARAMS = "tuning_params" CONTROL_DATA = "control_data" DISCOVER_RESPONSE = "discover_response" NEIGHBOURS_RESPONSE = "neighbours_response" @@ -104,6 +107,17 @@ class Event: if kwargs: self.attributes.update(kwargs) + def is_error(self) -> bool: + """Return True if this event represents an error response. + + Callers that include ``EventType.ERROR`` in their expected-events + list **must** check ``result.is_error()`` (or ``result.type == + EventType.ERROR``) before accessing keyed payload fields, because + an ERROR payload contains ``{"reason": "..."}`` — not the + command-specific keys the caller expects on the happy path. + """ + return self.type == EventType.ERROR + def clone(self): """ Create a copy of the event. @@ -129,11 +143,28 @@ class Subscription: class EventDispatcher: + """Event dispatch engine. + + .. note:: + ``start()`` must be called before dispatching or processing events. + The internal ``asyncio.Queue`` is created lazily inside ``start()`` + so that it binds to the correct running event loop (required for + Python 3.9/3.10 compatibility). + """ + def __init__(self): - self.queue: asyncio.Queue[Event] = asyncio.Queue() + self.queue: Optional[asyncio.Queue[Event]] = None self.subscriptions: List[Subscription] = [] self.running = False self._task = None + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task def subscribe( self, @@ -166,6 +197,10 @@ class EventDispatcher: self.subscriptions.remove(subscription) async def dispatch(self, event: Event): + if self.queue is None: + raise RuntimeError( + "EventDispatcher.start() must be called before dispatching events" + ) await self.queue.put(event) async def _process_events(self): @@ -197,7 +232,7 @@ class EventDispatcher: # returns - avoids the race where create_task schedules the callback after # the waiter has already timed out with done=set(). if asyncio.iscoroutinefunction(subscription.callback): - asyncio.create_task(self._execute_callback(subscription.callback, event.clone())) + self._spawn_background(self._execute_callback(subscription.callback, event.clone())) else: try: subscription.callback(event.clone()) @@ -220,6 +255,8 @@ class EventDispatcher: async def start(self): if not self.running: + if self.queue is None: + self.queue = asyncio.Queue() self.running = True self._task = asyncio.create_task(self._process_events()) @@ -227,7 +264,12 @@ class EventDispatcher: if self.running: self.running = False if self._task: - await self.queue.join() + if self.queue is not None: + await self.queue.join() + # Wait for any in-flight async callbacks to complete before + # tearing down (F07: task_done fires before callbacks finish). + if self._background_tasks: + await asyncio.gather(*self._background_tasks, return_exceptions=True) self._task.cancel() try: await self._task diff --git a/src/meshcore/meshcore.py b/src/meshcore/meshcore.py index bdd0db3..af1cfa4 100644 --- a/src/meshcore/meshcore.py +++ b/src/meshcore/meshcore.py @@ -28,10 +28,17 @@ class MeshCore: auto_reconnect: bool = False, max_reconnect_attempts: int = 3, ): - # Wrap connection with ConnectionManager + # Wrap connection with ConnectionManager. + # The reconnect callback ensures send_appstart() runs after every + # transport-level reconnect, which is required by firmware to + # initialize the session (F02). self.dispatcher = EventDispatcher() self.connection_manager = ConnectionManager( - cx, self.dispatcher, auto_reconnect, max_reconnect_attempts + cx, + self.dispatcher, + auto_reconnect, + max_reconnect_attempts, + reconnect_callback=self._on_reconnect, ) self.cx = self.connection_manager # For backward compatibility @@ -174,6 +181,15 @@ class MeshCore: return None return mc + async def _on_reconnect(self): + """Callback invoked by ConnectionManager after a successful reconnect. + + Firmware requires CMD_APP_START after every transport-level connection + to initialize the session. MeshCore.connect() does this on the initial + connection; this callback ensures it also happens on reconnects (F02). + """ + await self.commands.send_appstart() + async def connect(self): await self.dispatcher.start() result = await self.connection_manager.connect() diff --git a/src/meshcore/packets.py b/src/meshcore/packets.py index b5cef23..ad27fce 100644 --- a/src/meshcore/packets.py +++ b/src/meshcore/packets.py @@ -71,6 +71,7 @@ class CommandType(Enum): SET_AUTOADD_CONFIG = 58 GET_AUTOADD_CONFIG = 59 GET_ALLOWED_REPEAT_FREQ = 60 + GET_STATS = 56 # R04: CMD_GET_STATS — used by get_stats_core/radio/packets SET_PATH_HASH_MODE = 61 # Packet prefixes for the protocol @@ -120,3 +121,6 @@ class PacketType(Enum): PATH_DISCOVERY_RESPONSE = 0x8D CONTROL_DATA = 0x8E CONTACT_DELETED = 0x8F + CONTACTS_FULL = 0x90 # N02: MyMesh::onContactsFull() — 1-byte push, no payload + # Note: 0x90 == ControlType.NODE_DISCOVER_RESP in a different namespace. + # Not a literal conflict (PacketType vs ControlType), but a maintenance hazard. diff --git a/src/meshcore/reader.py b/src/meshcore/reader.py index 518eb5c..5ae2bce 100644 --- a/src/meshcore/reader.py +++ b/src/meshcore/reader.py @@ -552,9 +552,9 @@ class MessageReader: perms = dbuf.read(1)[0] res["permissions"] = perms res["is_admin"] = (perms & 1) == 1 # Check if admin bit is set - + if len(data) > 7: res["pubkey_prefix"] = dbuf.read(6).hex() - + attributes = {"pubkey_prefix": res.get("pubkey_prefix")} await self.dispatcher.dispatch( @@ -942,6 +942,36 @@ class MessageReader: await self.dispatcher.dispatch( Event(EventType.DISCOVER_RESPONSE, ndr, attributes) ) + elif packet_type_value == PacketType.CONTACT_DELETED.value: + # N01: PUSH_CODE_CONTACT_DELETED (0x8F) — 1-byte code + 32-byte pubkey + # Emitted by MyMesh::onContactOverwrite() (MyMesh.cpp:325-334) + if len(data) < 33: + logger.debug("CONTACT_DELETED frame too short (%d bytes, need 33)", len(data)) + return + pubkey = data[1:33].hex() + await self.dispatcher.dispatch( + Event(EventType.CONTACT_DELETED, {"pubkey": pubkey}, {"pubkey": pubkey}) + ) + + elif packet_type_value == PacketType.CONTACTS_FULL.value: + # N02: PUSH_CODE_CONTACTS_FULL (0x90) — 1-byte push, no payload + # Emitted by MyMesh::onContactsFull() (MyMesh.cpp:336) + await self.dispatcher.dispatch(Event(EventType.CONTACTS_FULL, {})) + + elif packet_type_value == PacketType.TUNING_PARAMS.value: + # N03: RESP_CODE_TUNING_PARAMS (23) — response to CMD_GET_TUNING_PARAMS (43) + # Format: 1-byte code + 4-byte rx_delay (LE) + 4-byte airtime_factor (LE) = 9 bytes + # Emitted by MyMesh.cpp:1307-1313 + if len(data) < 9: + logger.debug("TUNING_PARAMS frame too short (%d bytes, need 9)", len(data)) + await self.dispatcher.dispatch( + Event(EventType.ERROR, {"reason": "invalid_frame_length"}) + ) + return + rx_delay = int.from_bytes(data[1:5], byteorder="little") + airtime_factor = int.from_bytes(data[5:9], byteorder="little") + res = {"rx_delay": rx_delay, "airtime_factor": airtime_factor} + await self.dispatcher.dispatch(Event(EventType.TUNING_PARAMS, res)) else: logger.debug(f"Unhandled data received {data}") @@ -953,4 +983,4 @@ class MessageReader: e, data.hex(), traceback.format_exc(), - ) + ) \ No newline at end of file diff --git a/src/meshcore/serial_cx.py b/src/meshcore/serial_cx.py index 61163bd..088142b 100644 --- a/src/meshcore/serial_cx.py +++ b/src/meshcore/serial_cx.py @@ -20,11 +20,19 @@ class SerialConnection: self._disconnect_callback = None self.cx_dly = cx_dly self._connected_event = asyncio.Event() + self._background_tasks: set[asyncio.Task] = set() self.frame_expected_size = 0 self.inframe = b"" self.header = b"" + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task + class MCSerialClientProtocol(asyncio.Protocol): def __init__(self, cx): self.cx = cx @@ -44,7 +52,7 @@ class SerialConnection: self.cx._connected_event.clear() if self.cx._disconnect_callback: - asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) + self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect")) def pause_writing(self): logger.debug("pause writing") @@ -52,12 +60,16 @@ class SerialConnection: def resume_writing(self): logger.debug("resume writing") - async def connect(self): + async def connect(self, timeout: float = 10.0): """ - Connects to the device + Connects to the device. + + Args: + timeout: Maximum seconds to wait for connection_made callback. + Defaults to 10.0. Raises asyncio.TimeoutError on expiry. """ self._connected_event.clear() - + loop = asyncio.get_running_loop() await serial_asyncio.create_serial_connection( loop, @@ -66,7 +78,7 @@ class SerialConnection: baudrate=self.baudrate, ) - await self._connected_event.wait() + await asyncio.wait_for(self._connected_event.wait(), timeout=timeout) logger.info("Serial Connection started") return self.port @@ -102,7 +114,7 @@ class SerialConnection: self.frame_expected_size = 0 if len(data) > 0: # rerun handle_rx on remaining data self.handle_rx(data) - return + return # nothing left to process after reset upbound = self.frame_expected_size - len(self.inframe) if len(data) < upbound: @@ -114,7 +126,7 @@ class SerialConnection: data = data[upbound:] if self.reader is not None: # feed meshcore reader - asyncio.create_task(self.reader.handle_rx(self.inframe)) + self._spawn_background(self.reader.handle_rx(self.inframe)) # reset inframe self.inframe = b"" self.header = b"" @@ -125,11 +137,18 @@ class SerialConnection: async def send(self, data): if not self.transport: logger.error("Transport not connected, cannot send data") + if self._disconnect_callback: + await self._disconnect_callback("serial_transport_lost") return size = len(data) pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") - self.transport.write(pkt) + try: + self.transport.write(pkt) + except OSError as exc: + logger.warning(f"Serial write failed: {exc}") + if self._disconnect_callback: + await self._disconnect_callback(f"serial_write_failed: {exc}") async def disconnect(self): """Close the serial connection.""" diff --git a/src/meshcore/tcp_cx.py b/src/meshcore/tcp_cx.py index 497c3b2..80c8af3 100644 --- a/src/meshcore/tcp_cx.py +++ b/src/meshcore/tcp_cx.py @@ -24,6 +24,14 @@ class TCPConnection: self.frame_expected_size = 0 self.header = b"" self.inframe = b"" + self._background_tasks: set[asyncio.Task] = set() + + def _spawn_background(self, coro) -> asyncio.Task: + """Create a tracked background task (prevents GC of fire-and-forget tasks).""" + task = asyncio.create_task(coro) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return task class MCClientProtocol(asyncio.Protocol): def __init__(self, cx): @@ -38,7 +46,6 @@ class TCPConnection: def data_received(self, data): logger.debug("data received") - self.cx._receive_count += 1 self.cx.handle_rx(data) def error_received(self, exc): @@ -47,7 +54,7 @@ class TCPConnection: def connection_lost(self, exc): logger.debug("TCP server closed the connection") if self.cx._disconnect_callback: - asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect")) + self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect")) async def connect(self): """ @@ -59,10 +66,7 @@ class TCPConnection: ) logger.info("TCP Connection started") - future = asyncio.Future() - future.set_result(self.host) - - return future + return self.host def set_reader(self, reader): self.reader = reader @@ -96,7 +100,7 @@ class TCPConnection: self.frame_expected_size = 0 if len(data) > 0: # rerun handle_rx on remaining data self.handle_rx(data) - return + return # nothing left to process after reset upbound = self.frame_expected_size - len(self.inframe) if len(data) < upbound : @@ -106,9 +110,13 @@ class TCPConnection: self.inframe = self.inframe + data[0:upbound] data = data[upbound:] + # Increment per completed MeshCore frame, not per TCP segment (N04). + # The threshold heuristic in send() compares _send_count vs + # _receive_count — counting per-segment skews it under fragmentation. + self._receive_count += 1 if self.reader is not None: # feed meshcore reader - asyncio.create_task(self.reader.handle_rx(self.inframe)) + self._spawn_background(self.reader.handle_rx(self.inframe)) # reset inframe self.inframe = b"" self.header = b"" @@ -137,7 +145,12 @@ class TCPConnection: size = len(data) pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data logger.debug(f"sending pkt : {pkt}") - self.transport.write(pkt) + try: + self.transport.write(pkt) + except (OSError, ConnectionResetError) as exc: + logger.warning(f"TCP write failed: {exc}") + if self._disconnect_callback: + await self._disconnect_callback(f"tcp_write_failed: {exc}") async def disconnect(self): """Close the TCP connection.""" diff --git a/tests/test_ble_pin_pairing.py b/tests/test_ble_pin_pairing.py index aaa3acc..8b49184 100644 --- a/tests/test_ble_pin_pairing.py +++ b/tests/test_ble_pin_pairing.py @@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase): @patch("meshcore.ble_cx.BleakClient") def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client): - """Test BLE connection with PIN when pairing fails but connection continues""" + """Test BLE connection with PIN when pairing fails — re-raises (F17).""" # Arrange mock_client_instance = self._get_mock_bleak_client() mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed")) @@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase): pin = "123456" ble_conn = BLEConnection(address=address, pin=pin) - # Act - result = asyncio.run(ble_conn.connect()) - - # Assert + # Act & Assert — pairing failure now re-raises instead of being + # swallowed, because a half-usable transport is worse than a clean + # failure (forensics finding F17). + with self.assertRaises(Exception) as ctx: + asyncio.run(ble_conn.connect()) + self.assertIn("Pairing failed", str(ctx.exception)) mock_client_instance.connect.assert_called_once() mock_client_instance.pair.assert_called_once() - mock_client_instance.start_notify.assert_called_once_with( - UART_TX_CHAR_UUID, ble_conn.handle_rx - ) - # Connection should still succeed even if pairing fails - self.assertEqual(result, address) + # disconnect should be called to clean up the failed connection + mock_client_instance.disconnect.assert_called_once() @patch("meshcore.ble_cx.BleakClient") def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client): diff --git a/tests/unit/test_asyncio_lifecycle.py b/tests/unit/test_asyncio_lifecycle.py new file mode 100644 index 0000000..f70f2f0 --- /dev/null +++ b/tests/unit/test_asyncio_lifecycle.py @@ -0,0 +1,235 @@ +""" +Verification tests for asyncio lifecycle fixes. +""" + +import asyncio +import gc +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from meshcore.events import Event, EventDispatcher, EventType +from meshcore.tcp_cx import TCPConnection +from meshcore.serial_cx import SerialConnection +from meshcore.commands.base import CommandHandlerBase + + +class TestBackgroundTaskTracking(unittest.TestCase): + """Fire-and-forget create_task calls must be tracked to prevent GC.""" + + def test_tcp_spawn_background_retains_task(self): + """TCP _spawn_background adds the task to _background_tasks.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = cx._spawn_background(dummy()) + assert task in cx._background_tasks + await completed.wait() + # After completion, done_callback should have discarded it + await asyncio.sleep(0) # let done callback fire + assert task not in cx._background_tasks + + asyncio.run(_run()) + + def test_serial_spawn_background_retains_task(self): + """Serial _spawn_background adds the task to _background_tasks.""" + async def _run(): + with patch("meshcore.serial_cx.asyncio.Event") as mock_event: + mock_event.return_value = MagicMock() + cx = SerialConnection("/dev/null", 115200) + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = cx._spawn_background(dummy()) + assert task in cx._background_tasks + await completed.wait() + await asyncio.sleep(0) + assert task not in cx._background_tasks + + asyncio.run(_run()) + + def test_event_dispatcher_spawn_background_retains_task(self): + """EventDispatcher _spawn_background adds task to _background_tasks.""" + async def _run(): + dispatcher = EventDispatcher() + completed = asyncio.Event() + + async def dummy(): + completed.set() + + task = dispatcher._spawn_background(dummy()) + assert task in dispatcher._background_tasks + await completed.wait() + await asyncio.sleep(0) + assert task not in dispatcher._background_tasks + + asyncio.run(_run()) + + def test_tcp_handle_rx_uses_tracked_task(self): + """TCP handle_rx dispatches reader.handle_rx via _spawn_background.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + reader = AsyncMock() + reader.handle_rx = AsyncMock() + cx.set_reader(reader) + + # Build a minimal valid frame: 0x3e + 2-byte LE size + payload + payload = b"\x01\x02\x03" + size = len(payload).to_bytes(2, "little") + frame = b"\x3e" + size + payload + + cx.handle_rx(frame) + # Task should be tracked + assert len(cx._background_tasks) == 1 + # Let task complete + await asyncio.sleep(0.05) + reader.handle_rx.assert_awaited_once_with(payload) + + asyncio.run(_run()) + + def test_tcp_connection_lost_uses_tracked_task(self): + """TCP connection_lost dispatches disconnect callback via _spawn_background.""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + callback = AsyncMock() + cx.set_disconnect_callback(callback) + + protocol = cx.MCClientProtocol(cx) + protocol.connection_lost(None) + + assert len(cx._background_tasks) == 1 + await asyncio.sleep(0.05) + callback.assert_awaited_once_with("tcp_disconnect") + + asyncio.run(_run()) + + def test_gc_does_not_cancel_tracked_tasks(self): + """Tracked tasks survive GC pressure (the whole point of tracking).""" + async def _run(): + cx = TCPConnection("127.0.0.1", 5555) + result = [] + + async def slow_task(): + await asyncio.sleep(0.05) + result.append("done") + + cx._spawn_background(slow_task()) + # Force GC — untracked tasks could be collected here + gc.collect() + await asyncio.sleep(0.1) + assert result == ["done"] + + asyncio.run(_run()) + + +class TestTaskDoneCorrectness(unittest.TestCase): + """EventDispatcher.stop() must wait for in-flight async callbacks.""" + + def test_stop_waits_for_async_callbacks(self): + """stop() should not return until async callbacks have completed.""" + async def _run(): + dispatcher = EventDispatcher() + await dispatcher.start() + + callback_completed = False + + async def slow_callback(event): + nonlocal callback_completed + await asyncio.sleep(0.1) + callback_completed = True + + dispatcher.subscribe(EventType.OK, slow_callback) + await dispatcher.dispatch(Event(EventType.OK, {})) + + # Give the dispatch loop a moment to pick up the event + await asyncio.sleep(0.02) + + # stop() should wait for slow_callback to finish + await dispatcher.stop() + assert callback_completed, "stop() returned before async callback completed" + + asyncio.run(_run()) + + +class TestDeferredPrimitiveConstruction(unittest.TestCase): + """Queue and Lock must not bind to import-time loop.""" + + def test_event_dispatcher_queue_is_none_before_start(self): + """EventDispatcher.queue should be None until start() is called.""" + dispatcher = EventDispatcher() + assert dispatcher.queue is None + + def test_event_dispatcher_queue_created_on_start(self): + """start() creates the queue.""" + async def _run(): + dispatcher = EventDispatcher() + assert dispatcher.queue is None + await dispatcher.start() + assert dispatcher.queue is not None + assert isinstance(dispatcher.queue, asyncio.Queue) + await dispatcher.stop() + + asyncio.run(_run()) + + def test_event_dispatcher_dispatch_before_start_raises(self): + """dispatch() before start() should raise RuntimeError.""" + async def _run(): + dispatcher = EventDispatcher() + with self.assertRaises(RuntimeError): + await dispatcher.dispatch(Event(EventType.OK, {})) + + asyncio.run(_run()) + + def test_command_handler_lock_is_none_before_use(self): + """CommandHandlerBase lock should be None until first access.""" + handler = CommandHandlerBase() + assert handler._CommandHandlerBase__mesh_request_lock is None + + def test_command_handler_lock_created_on_access(self): + """Accessing _mesh_request_lock creates it lazily.""" + async def _run(): + handler = CommandHandlerBase() + lock = handler._mesh_request_lock + assert isinstance(lock, asyncio.Lock) + # Second access returns same instance + assert handler._mesh_request_lock is lock + + asyncio.run(_run()) + + +class TestGetRunningLoop(unittest.TestCase): + """get_event_loop() replaced with get_running_loop() in send().""" + + def test_send_uses_get_running_loop(self): + """send() should call get_running_loop, not get_event_loop.""" + async def _run(): + handler = CommandHandlerBase() + dispatcher = EventDispatcher() + await dispatcher.start() + handler.set_dispatcher(dispatcher) + + mock_sender = AsyncMock() + handler._sender_func = mock_sender + + # Patch get_running_loop to verify it's called + with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl: + # send with expected_events triggers the loop = asyncio.get_running_loop() path + result = await handler.send( + b"\x01", + expected_events=[EventType.OK], + timeout=0.05, + ) + mock_grl.assert_called() + + await dispatcher.stop() + + asyncio.run(_run()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index 44dff6f..02516e1 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -28,6 +28,10 @@ def mock_dispatcher(): sub.unsubscribe = MagicMock() dispatcher._last_subscribe_handler = handler dispatcher._last_subscribe_event_type = event_type + # Immediately resolve the future so send() doesn't block + asyncio.get_event_loop().call_soon( + handler, Event(event_type, {}) + ) return sub dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) @@ -80,6 +84,13 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher async def test_send_timeout(command_handler, mock_connection, mock_dispatcher): + # Override to NOT resolve events, so we can test the timeout path + def non_resolving_subscribe(event_type, handler, attribute_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + return sub + mock_dispatcher.subscribe = MagicMock(side_effect=non_resolving_subscribe) + result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1) assert result.type == EventType.ERROR assert result.payload == {"reason": "no_event_received"} diff --git a/tests/unit/test_connection_manager.py b/tests/unit/test_connection_manager.py new file mode 100644 index 0000000..d92cf06 --- /dev/null +++ b/tests/unit/test_connection_manager.py @@ -0,0 +1,294 @@ +"""Tests for reconnect-path fixes.""" + +import asyncio + +import pytest + +from meshcore.connection_manager import ConnectionManager +from meshcore.events import Event, EventDispatcher, EventType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class FakeConnection: + """Minimal stub that satisfies ConnectionProtocol.""" + + def __init__(self, connect_results=None): + """ + Args: + connect_results: iterator of return values for successive + connect() calls. ``None`` means soft failure; a string + means success; raising is also supported via sentinel. + """ + self._connect_results = list(connect_results or ["ok"]) + self._call_index = 0 + self.reader = None + + async def connect(self): + if self._call_index < len(self._connect_results): + result = self._connect_results[self._call_index] + self._call_index += 1 + else: + result = self._connect_results[-1] + if isinstance(result, Exception): + raise result + return result + + async def disconnect(self): + pass + + async def send(self, data): + pass + + def set_reader(self, reader): + self.reader = reader + + +class RaisingConnection(FakeConnection): + """Connection that raises on every connect() attempt.""" + + def __init__(self, exc=None): + super().__init__() + self._exc = exc or ConnectionError("boom") + + async def connect(self): + raise self._exc + + +class _EventCollector: + """Subscribes to all events and records them.""" + + def __init__(self, dispatcher: EventDispatcher): + self.events: list[Event] = [] + dispatcher.subscribe(None, self._on_event) + + async def _on_event(self, event: Event): + self.events.append(event) + + +# --------------------------------------------------------------------------- +# TCP connect() should return a plain value, not an asyncio.Future +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_connect_returns_plain_string(): + """TCPConnection.connect() returns self.host (a plain string), not an + asyncio.Future. We test indirectly via ConnectionManager — the + CONNECTED event payload should contain a plain string, not a Future + object.""" + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager(conn, dispatcher) + + result = await mgr.connect() + + assert result == "10.0.0.1" + # Give the dispatcher a moment to deliver the event + await asyncio.sleep(0.05) + connected_events = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected_events) == 1 + payload = connected_events[0].payload + assert payload["connection_info"] == "10.0.0.1" + # The payload value must NOT be an asyncio.Future + assert not isinstance(payload["connection_info"], asyncio.Future) + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# Reconnect attempts must not compound (no tail-recursive create_task) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_reconnect_loop_does_not_compound(): + """_attempt_reconnect must use a single iterative loop. After + max_reconnect_attempts failures, exactly that many connect() calls + should have been made — no exponential fan-out from orphaned tasks.""" + # All attempts fail (return None) + conn = FakeConnection(connect_results=[None, None, None, None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager( + conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3, + ) + mgr._is_connected = True # simulate a live connection + + await mgr.handle_disconnect("test_disconnect") + # Wait for the reconnect loop to exhaust all attempts + # (3 attempts × 1s sleep each, but we can just await the task) + if mgr._reconnect_task: + await mgr._reconnect_task + + # Exactly 3 connect() calls should have been made + assert conn._call_index == 3 + + # A DISCONNECTED event with max_attempts_exceeded should have fired + await asyncio.sleep(0.05) + disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED] + assert len(disconnected) == 1 + assert disconnected[0].payload.get("max_attempts_exceeded") is True + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_disconnect_cancels_reconnect_loop(): + """disconnect() during an active reconnect loop must cancel the + single task cleanly — no orphaned tasks left running.""" + # Simulate a connection that always fails (returns None), giving us + # time to call disconnect() mid-loop. + conn = FakeConnection(connect_results=[None, None, None, None, None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + + # Let the first attempt start (wait just past the 1s sleep) + await asyncio.sleep(1.2) + assert conn._call_index >= 1 # at least one attempt made + + # Now disconnect — should cancel the loop + await mgr.disconnect() + + assert mgr._reconnect_task is None + calls_at_cancel = conn._call_index + + # Wait a bit and confirm no more attempts happened + await asyncio.sleep(2) + assert conn._call_index == calls_at_cancel + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# reconnect_callback (send_appstart) is called after reconnect +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_reconnect_callback_called_after_reconnect(): + """When ConnectionManager reconnects successfully, the + reconnect_callback (e.g. send_appstart) must be invoked.""" + callback_called = [] + + async def fake_appstart(): + callback_called.append(True) + + # First connect() fails (None), second succeeds + conn = FakeConnection(connect_results=[None, "10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_callback=fake_appstart, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + assert len(callback_called) == 1 + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_reconnect_callback_failure_does_not_crash_loop(): + """If the reconnect_callback raises, the reconnect still counts as + successful (transport is up) — the callback failure is logged but + does not crash the loop or leave the manager in a broken state.""" + async def failing_callback(): + raise RuntimeError("appstart failed") + + # connect() succeeds on first attempt + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + reconnect_callback=failing_callback, + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + # Despite callback failure, CONNECTED event should have fired + await asyncio.sleep(0.05) + connected = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected) == 1 + assert mgr._is_connected is True + finally: + await dispatcher.stop() + + +# --------------------------------------------------------------------------- +# connect() returning None is a soft failure (BLE scan miss) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_connect_none_is_soft_failure(): + """When connect() returns None (e.g. BLE scan found no device), + ConnectionManager.connect() should NOT set _is_connected and should + NOT emit a CONNECTED event.""" + conn = FakeConnection(connect_results=[None]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + collector = _EventCollector(dispatcher) + mgr = ConnectionManager(conn, dispatcher) + + result = await mgr.connect() + + assert result is None + assert mgr._is_connected is False + await asyncio.sleep(0.05) + connected = [e for e in collector.events if e.type == EventType.CONNECTED] + assert len(connected) == 0 + finally: + await dispatcher.stop() + + +@pytest.mark.asyncio +async def test_no_reconnect_callback_is_noop(): + """When no reconnect_callback is provided (backwards compat for + direct ConnectionManager users), reconnect should still work.""" + conn = FakeConnection(connect_results=["10.0.0.1"]) + dispatcher = EventDispatcher() + await dispatcher.start() + try: + mgr = ConnectionManager( + conn, dispatcher, + auto_reconnect=True, + max_reconnect_attempts=3, + # No reconnect_callback — default None + ) + mgr._is_connected = True + + await mgr.handle_disconnect("test_disconnect") + if mgr._reconnect_task: + await mgr._reconnect_task + + assert mgr._is_connected is True + finally: + await dispatcher.stop() diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py new file mode 100644 index 0000000..7f88e28 --- /dev/null +++ b/tests/unit/test_error_handling.py @@ -0,0 +1,236 @@ +"""Verification tests for error response handling fixes. + +The tests confirm that error responses are surfaced cleanly instead +of causing KeyError, TypeError, NameError, or silent fallthrough. +""" +import asyncio +import pytest +from unittest.mock import MagicMock, AsyncMock, patch + +from meshcore.commands import CommandHandler +from meshcore.events import EventType, Event, Subscription + +pytestmark = pytest.mark.asyncio + +VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes + + +# ── Fixtures ─────────────────────────────────────────────────────── + +@pytest.fixture +def mock_connection(): + connection = MagicMock() + connection.send = AsyncMock() + return connection + + +@pytest.fixture +def mock_dispatcher(): + dispatcher = MagicMock() + dispatcher.wait_for_event = AsyncMock() + dispatcher.dispatch = AsyncMock() + + def fake_subscribe(event_type, handler, attribute_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + dispatcher._last_subscribe_handler = handler + dispatcher._last_subscribe_event_type = event_type + return sub + + dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + return dispatcher + + +@pytest.fixture +def command_handler(mock_connection, mock_dispatcher): + handler = CommandHandler() + + async def sender(data): + await mock_connection.send(data) + + handler._sender_func = sender + handler.dispatcher = mock_dispatcher + return handler + + +def setup_error_response(mock_dispatcher): + """Configure dispatcher to return an ERROR event for any subscribe.""" + def fake_subscribe(evt_type, handler, attr_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + # Always fire ERROR regardless of which event type was subscribed + if evt_type == EventType.ERROR: + asyncio.get_event_loop().call_soon( + handler, Event(EventType.ERROR, {"reason": "test_error"}) + ) + return sub + + mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + + +def setup_event_response(mock_dispatcher, event_type, payload): + """Configure dispatcher to return a specific event.""" + def fake_subscribe(evt_type, handler, attr_filters=None): + sub = MagicMock(spec=Subscription) + sub.unsubscribe = MagicMock() + if evt_type == event_type: + asyncio.get_event_loop().call_soon( + handler, Event(event_type, payload) + ) + return sub + + mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) + + +# ── Event.is_error() helper ────────────────────────────────── + +async def test_event_is_error_true(): + """is_error() returns True for ERROR events.""" + event = Event(EventType.ERROR, {"reason": "test"}) + assert event.is_error() is True + + +async def test_event_is_error_false(): + """is_error() returns False for non-ERROR events.""" + event = Event(EventType.OK, {}) + assert event.is_error() is False + event2 = Event(EventType.SELF_INFO, {"name": "test"}) + assert event2.is_error() is False + + +# ── send_msg_with_retry continues on ERROR ────────────── + +async def test_send_msg_with_retry_error_no_keyerror( + command_handler, mock_dispatcher +): + """send_msg_with_retry returns None (exhausted retries) on + persistent ERROR instead of raising KeyError on missing 'expected_ack'.""" + setup_error_response(mock_dispatcher) + + # Provide a mock contact so the path logic doesn't interfere + command_handler._get_contact_by_prefix = MagicMock(return_value=None) + + # max_attempts=2 so it retries once then gives up + result = await command_handler.send_msg_with_retry( + VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1 + ) + + # Should return None (no ACK received) rather than raising KeyError + assert result is None + + +# ── send_appstart includes ERROR in expected events ────────── + +async def test_send_appstart_returns_error( + command_handler, mock_dispatcher +): + """send_appstart returns ERROR event instead of hanging on timeout.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.send_appstart() + + assert result.type == EventType.ERROR + assert result.is_error() is True + assert result.payload["reason"] == "test_error" + + +# ── device setters return ERROR from send_appstart ─────────── + +async def test_set_telemetry_mode_base_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_base returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_base(1) + + assert result.is_error() + assert result.payload["reason"] == "test_error" + + +async def test_set_telemetry_mode_loc_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_loc returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_loc(1) + + assert result.is_error() + + +async def test_set_telemetry_mode_env_error( + command_handler, mock_dispatcher +): + """set_telemetry_mode_env returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_telemetry_mode_env(1) + + assert result.is_error() + + +async def test_set_manual_add_contacts_error( + command_handler, mock_dispatcher +): + """set_manual_add_contacts returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_manual_add_contacts(True) + + assert result.is_error() + + +async def test_set_advert_loc_policy_error( + command_handler, mock_dispatcher +): + """set_advert_loc_policy returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_advert_loc_policy(1) + + assert result.is_error() + + +async def test_set_multi_acks_error( + command_handler, mock_dispatcher +): + """set_multi_acks returns ERROR instead of KeyError.""" + setup_error_response(mock_dispatcher) + + result = await command_handler.set_multi_acks(1) + + assert result.is_error() + + +# ── send_anon_req returns ERROR on contact not found ───────── + +async def test_send_anon_req_contact_not_found( + command_handler, mock_dispatcher +): + """send_anon_req returns ERROR event when contact prefix not found, + instead of raising TypeError on NoneType subscript.""" + command_handler._get_contact_by_prefix = MagicMock(return_value=None) + + result = await command_handler.send_anon_req( + VALID_PUBKEY_HEX, MagicMock(value=1) + ) + + assert result.is_error() + assert result.payload["reason"] == "contact_not_found" + + +# ── send_trace handles unknown path_hash_len without NameError ── + +async def test_send_trace_unknown_path_hash_len( + command_handler, mock_connection, mock_dispatcher +): + """send_trace with a path whose segments don't match any known + path_hash_len returns ERROR cleanly instead of NameError on 'e'.""" + # 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8 + result = await command_handler.send_trace( + auth_code=0, tag=1, flags=None, path="abcde" + ) + + assert result.is_error() + assert result.payload["reason"] == "invalid_path_format" diff --git a/tests/unit/test_protocol_surface_gaps.py b/tests/unit/test_protocol_surface_gaps.py new file mode 100644 index 0000000..a490a81 --- /dev/null +++ b/tests/unit/test_protocol_surface_gaps.py @@ -0,0 +1,364 @@ +"""Verification tests for protocol surface gaps. + +Each test constructs a mock firmware frame and verifies the SDK dispatches +the correct EventType with the expected payload fields. +""" + +import asyncio +import struct +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from meshcore.events import Event, EventType, EventDispatcher +from meshcore.reader import MessageReader +from meshcore.packets import PacketType, CommandType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_reader(): + """Create a MessageReader with a mock dispatcher that records dispatched events.""" + dispatcher = MagicMock(spec=EventDispatcher) + dispatched = [] + + async def _capture(event): + dispatched.append(event) + + dispatcher.dispatch = AsyncMock(side_effect=_capture) + reader = MessageReader(dispatcher) + return reader, dispatched + + +# --------------------------------------------------------------------------- +# CONTACT_DELETED handler +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_contact_deleted_dispatches_event(): + """A 33-byte CONTACT_DELETED frame dispatches EventType.CONTACT_DELETED.""" + reader, dispatched = _make_reader() + pubkey = bytes(range(32)) + frame = bytes([PacketType.CONTACT_DELETED.value]) + pubkey + assert len(frame) == 33 + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.CONTACT_DELETED + assert evt.payload["pubkey"] == pubkey.hex() + assert evt.attributes["pubkey"] == pubkey.hex() + + +@pytest.mark.asyncio +async def test_contact_deleted_short_frame_ignored(): + """A CONTACT_DELETED frame shorter than 33 bytes is silently dropped.""" + reader, dispatched = _make_reader() + # Only 10 bytes — too short + frame = bytes([PacketType.CONTACT_DELETED.value]) + b"\x00" * 9 + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 0 + + +# --------------------------------------------------------------------------- +# CONTACTS_FULL handler + enum entry +# --------------------------------------------------------------------------- + +def test_contacts_full_enum_exists(): + """PacketType.CONTACTS_FULL == 0x90.""" + assert PacketType.CONTACTS_FULL.value == 0x90 + + +@pytest.mark.asyncio +async def test_contacts_full_dispatches_event(): + """A 1-byte CONTACTS_FULL push dispatches EventType.CONTACTS_FULL.""" + reader, dispatched = _make_reader() + frame = bytes([PacketType.CONTACTS_FULL.value]) + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.CONTACTS_FULL + assert evt.payload == {} + + +# --------------------------------------------------------------------------- +# TUNING_PARAMS handler +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tuning_params_dispatches_event(): + """A 9-byte TUNING_PARAMS frame dispatches with rx_delay and airtime_factor.""" + reader, dispatched = _make_reader() + rx_delay = 500 + airtime_factor = 200 + frame = ( + bytes([PacketType.TUNING_PARAMS.value]) + + rx_delay.to_bytes(4, "little") + + airtime_factor.to_bytes(4, "little") + ) + assert len(frame) == 9 + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.TUNING_PARAMS + assert evt.payload["rx_delay"] == 500 + assert evt.payload["airtime_factor"] == 200 + + +@pytest.mark.asyncio +async def test_tuning_params_short_frame_dispatches_error(): + """A TUNING_PARAMS frame shorter than 9 bytes dispatches ERROR.""" + reader, dispatched = _make_reader() + # Only 5 bytes — too short + frame = bytes([PacketType.TUNING_PARAMS.value]) + b"\x01\x00\x00\x00" + + await reader.handle_rx(bytearray(frame)) + + assert len(dispatched) == 1 + evt = dispatched[0] + assert evt.type == EventType.ERROR + assert evt.payload["reason"] == "invalid_frame_length" + + +# --------------------------------------------------------------------------- +# send_trace() one-byte pad +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_trace_empty_path_pads_to_11_bytes(): + """send_trace() with no path produces an 11-byte packet (not 10).""" + from meshcore.commands.messaging import MessagingCommands + + cmd = MessagingCommands.__new__(MessagingCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000}) + + cmd.send = mock_send + + await cmd.send_trace(auth_code=0, tag=1, flags=0, path=None) + + assert captured_data is not None + # cmd(1) + tag(4) + auth(4) + flags(1) + pad(1) = 11 + assert len(captured_data) == 11 + assert captured_data[-1] == 0x00 # The pad byte + + +@pytest.mark.asyncio +async def test_send_trace_with_path_no_padding(): + """send_trace() with a non-empty path does NOT add padding.""" + from meshcore.commands.messaging import MessagingCommands + + cmd = MessagingCommands.__new__(MessagingCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000}) + + cmd.send = mock_send + + # 2-byte path hash (flags=1 means hash_len=2) + await cmd.send_trace(auth_code=0, tag=1, flags=1, path=b"\xAA\xBB") + + assert captured_data is not None + # cmd(1) + tag(4) + auth(4) + flags(1) + path(2) = 12 — no pad needed + assert len(captured_data) == 12 + + +# --------------------------------------------------------------------------- +# Command wrapper: send_raw_data +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_send_raw_data_wrapper(): + """send_raw_data sends CMD 0x19 + payload.""" + from meshcore.commands.messaging import MessagingCommands + + cmd = MessagingCommands.__new__(MessagingCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000}) + + cmd.send = mock_send + + await cmd.send_raw_data(b"\xDE\xAD") + + assert captured_data is not None + assert captured_data[0] == 0x19 # CMD_SEND_RAW_DATA + assert captured_data[1:] == b"\xDE\xAD" + + +# --------------------------------------------------------------------------- +# Command wrapper: has_connection +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_has_connection_wrapper(): + """has_connection sends CMD 0x1c.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.OK, {"value": 1}) + + cmd.send = mock_send + + await cmd.has_connection() + + assert captured_data is not None + assert captured_data == b"\x1c" + + +# --------------------------------------------------------------------------- +# Command wrapper: get_tuning +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_get_tuning_wrapper(): + """get_tuning sends CMD 0x2b (GET_TUNING_PARAMS = 43).""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.TUNING_PARAMS, {"rx_delay": 500, "airtime_factor": 200}) + + cmd.send = mock_send + + result = await cmd.get_tuning() + + assert captured_data == b"\x2b" + assert result.type == EventType.TUNING_PARAMS + + +# --------------------------------------------------------------------------- +# Command wrapper: get_contact_by_key +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_get_contact_by_key_wrapper(): + """get_contact_by_key sends CMD 0x1e + 32-byte pubkey.""" + from meshcore.commands.contact import ContactCommands + + cmd = ContactCommands.__new__(ContactCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.NEXT_CONTACT, {"public_key": "ab" * 32}) + + cmd.send = mock_send + + pubkey = bytes(range(32)) + await cmd.get_contact_by_key(pubkey) + + assert captured_data is not None + assert captured_data[0] == 0x1E # CMD_GET_CONTACT_BY_KEY + assert captured_data[1:] == pubkey + + +# --------------------------------------------------------------------------- +# Command wrapper: factory_reset (two-step) +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_factory_reset_two_step(): + """factory_reset requires a token from request_factory_reset.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.OK, {}) + + cmd.send = mock_send + + # Step 1: request token + token = await cmd.request_factory_reset() + assert isinstance(token, str) + assert len(token) == 16 # hex-encoded 8 bytes + + # Step 2: confirm with wrong token fails + with pytest.raises(ValueError, match="Invalid or expired"): + await cmd.confirm_factory_reset("wrong_token") + + # Step 2: confirm with correct token succeeds + await cmd.confirm_factory_reset(token) + assert captured_data == b"\x33" # CMD_FACTORY_RESET + + +@pytest.mark.asyncio +async def test_factory_reset_without_request_fails(): + """confirm_factory_reset without request_factory_reset raises ValueError.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + with pytest.raises(ValueError, match="Invalid or expired"): + await cmd.confirm_factory_reset("any_token") + + +# --------------------------------------------------------------------------- +# GET_STATS enum entry +# --------------------------------------------------------------------------- + +def test_get_stats_enum_exists(): + """CommandType.GET_STATS == 56.""" + assert CommandType.GET_STATS.value == 56 + + +@pytest.mark.asyncio +async def test_get_stats_core_uses_enum(): + """get_stats_core sends CommandType.GET_STATS.value (0x38) + 0x00.""" + from meshcore.commands.device import DeviceCommands + + cmd = DeviceCommands.__new__(DeviceCommands) + + captured_data = None + + async def mock_send(data, expected_events, timeout=None): + nonlocal captured_data + captured_data = bytes(data) + return Event(EventType.STATS_CORE, {}) + + cmd.send = mock_send + + await cmd.get_stats_core() + + assert captured_data is not None + assert captured_data[0] == CommandType.GET_STATS.value # 0x38 = 56 + assert captured_data[1] == 0x00 diff --git a/tests/unit/test_transport_symmetry.py b/tests/unit/test_transport_symmetry.py new file mode 100644 index 0000000..1882f3d --- /dev/null +++ b/tests/unit/test_transport_symmetry.py @@ -0,0 +1,238 @@ +""" +Verification tests for transport symmetry fixes. + +Covers: send symmetry across transports, serial disconnect callback on +transport-lost, serial connect timeout, oversize-frame return, BLE +disconnect-callback re-registration, BLE pairing failure re-raise, +TCP counter per frame not per segment. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from meshcore.tcp_cx import TCPConnection +from meshcore.serial_cx import SerialConnection + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class RecordingReader: + """Minimal reader mock that records dispatched frames.""" + def __init__(self): + self.frames = [] + + async def handle_rx(self, data): + self.frames.append(bytes(data)) + + +# --------------------------------------------------------------------------- +# TCP send() wraps transport.write in try/except +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_send_write_error_fires_disconnect(): + """TCP: OSError during transport.write fires _disconnect_callback.""" + cx = TCPConnection("127.0.0.1", 5000) + cb = AsyncMock() + cx.set_disconnect_callback(cb) + + mock_transport = MagicMock() + mock_transport.write.side_effect = OSError("Broken pipe") + cx.transport = mock_transport + cx._send_count = 0 + cx._receive_count = 0 + + await cx.send(b"\x01\x02\x03") + + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert "tcp_write_failed" in reason + + +# --------------------------------------------------------------------------- +# Serial send() fires disconnect on transport-lost and write error +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_serial_send_no_transport_fires_disconnect(): + """Serial: send() on None transport fires _disconnect_callback .""" + cx = SerialConnection("/dev/null", 115200) + cb = AsyncMock() + cx.set_disconnect_callback(cb) + cx.transport = None + + await cx.send(b"\x01") + + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert reason == "serial_transport_lost" + + +@pytest.mark.asyncio +async def test_serial_send_write_error_fires_disconnect(): + """Serial: OSError during transport.write fires _disconnect_callback.""" + cx = SerialConnection("/dev/null", 115200) + cb = AsyncMock() + cx.set_disconnect_callback(cb) + + mock_transport = MagicMock() + mock_transport.write.side_effect = OSError("Device not configured") + cx.transport = mock_transport + + await cx.send(b"\x01") + + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert "serial_write_failed" in reason + + +# --------------------------------------------------------------------------- +# BLE send() fires disconnect on transport-lost and write error +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_ble_send_no_client_fires_disconnect(): + """BLE: send() with no client fires _disconnect_callback.""" + # Can't import BLEConnection directly if bleak isn't installed, + # so test via dynamic import with a guard. + try: + from meshcore.ble_cx import BLEConnection + except ImportError: + pytest.skip("bleak not installed") + + # BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it + with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \ + patch("meshcore.ble_cx.BleakClient", MagicMock()): + cx = BLEConnection.__new__(BLEConnection) + cx.client = None + cx._user_provided_client = None + cx._user_provided_address = None + cx._user_provided_device = None + cx.address = None + cx.device = None + cx.pin = None + cx.rx_char = None + cb = AsyncMock() + cx._disconnect_callback = cb + + result = await cx.send(b"\x01") + + assert result is False + cb.assert_awaited_once() + reason = cb.call_args[0][0] + assert reason == "ble_transport_lost" + + +# --------------------------------------------------------------------------- +# Serial connect() times out if connection_made never fires +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_serial_connect_timeout(): + """Serial: connect() raises TimeoutError if connection_made never fires.""" + cx = SerialConnection("/dev/null", 115200) + + # Mock create_serial_connection to do nothing (never fires connection_made) + async def mock_create(*args, **kwargs): + return (MagicMock(), MagicMock()) + + with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection", + side_effect=mock_create): + with pytest.raises(asyncio.TimeoutError): + await cx.connect(timeout=0.1) + + +# --------------------------------------------------------------------------- +# Oversize frame resets state and returns without dispatch +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_oversize_frame_empty_data_returns(): + """TCP: oversize header with no trailing data returns without dispatch.""" + cx = TCPConnection("127.0.0.1", 5000) + reader = RecordingReader() + cx.set_reader(reader) + + # Build a frame header with size > 300 and no payload data after header + # Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4) + header = b"\x3e" + (500).to_bytes(2, "little") + cx.handle_rx(header) + await asyncio.sleep(0) + + # No frames should be dispatched, and state should be reset + assert reader.frames == [] + assert cx.header == b"" + assert cx.inframe == b"" + assert cx.frame_expected_size == 0 + + +@pytest.mark.asyncio +async def test_serial_oversize_frame_empty_data_returns(): + """Serial: oversize header with no trailing data returns without dispatch.""" + cx = SerialConnection("/dev/null", 115200) + reader = RecordingReader() + cx.set_reader(reader) + + header = b"\x3e" + (500).to_bytes(2, "little") + cx.handle_rx(header) + await asyncio.sleep(0) + + assert reader.frames == [] + assert cx.header == b"" + assert cx.inframe == b"" + assert cx.frame_expected_size == 0 + + +# --------------------------------------------------------------------------- +# TCP receive counter increments per MeshCore frame, not per TCP segment +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_tcp_receive_count_per_frame_not_per_segment(): + """TCP: _receive_count increments per completed frame, not per data_received call.""" + cx = TCPConnection("127.0.0.1", 5000) + reader = RecordingReader() + cx.set_reader(reader) + cx._receive_count = 0 + + # Build a 4-byte payload frame + payload = b"\xAA\xBB\xCC\xDD" + frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload + + # Split the frame into 3 TCP segments (simulating fragmentation) + protocol = TCPConnection.MCClientProtocol(cx) + protocol.data_received(frame[:2]) # partial header + protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload + protocol.data_received(frame[5:]) # remaining payload + + await asyncio.sleep(0) + + # 3 data_received calls but only 1 completed frame + assert cx._receive_count == 1 + assert reader.frames == [payload] + + +@pytest.mark.asyncio +async def test_tcp_multiple_frames_count_correctly(): + """TCP: two complete frames in separate segments → _receive_count == 2.""" + cx = TCPConnection("127.0.0.1", 5000) + reader = RecordingReader() + cx.set_reader(reader) + cx._receive_count = 0 + + payload1 = b"\x01\x02" + frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1 + payload2 = b"\x03\x04\x05" + frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2 + + protocol = TCPConnection.MCClientProtocol(cx) + protocol.data_received(frame1) + protocol.data_received(frame2) + await asyncio.sleep(0) + + assert cx._receive_count == 2 + assert reader.frames == [payload1, payload2]