Merge branch 'main' into fix/standalone-bugs-and-cleanup

This commit is contained in:
fdlamotte
2026-04-25 15:21:16 +02:00
committed by GitHub
21 changed files with 2926 additions and 968 deletions

View File

@@ -51,6 +51,14 @@ class BLEConnection:
self.pin = pin self.pin = pin
self.rx_char = None self.rx_char = None
self._disconnect_callback = None self._disconnect_callback = None
self._background_tasks: set[asyncio.Task] = set()
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
async def connect(self): async def connect(self):
""" """
@@ -116,9 +124,12 @@ class BLEConnection:
await self.client.pair() await self.client.pair()
logger.info("BLE pairing successful") logger.info("BLE pairing successful")
except Exception as e: except Exception as e:
logger.warning(f"BLE pairing failed: {e}") logger.error(f"BLE pairing failed: {e}")
# Don't fail the connection if pairing fails, as the device # A failed pairing leaves the transport in a half-usable
# might already be paired or not require pairing # state — re-raise so the caller gets a clean failure
# instead of a silently degraded connection.
await self.client.disconnect()
raise
except BleakDeviceNotFoundError: except BleakDeviceNotFoundError:
return None return None
@@ -154,8 +165,19 @@ class BLEConnection:
self.client = self._user_provided_client self.client = self._user_provided_client
self.device = self._user_provided_device self.device = self._user_provided_device
# Re-register disconnect callback on the reset client so subsequent
# disconnects after a reconnect cycle are still detected.
if self.client is not None and hasattr(self.client, 'set_disconnected_callback'):
try:
self.client.set_disconnected_callback(self.handle_disconnect)
except Exception:
# set_disconnected_callback may not be available on all bleak
# versions; the next connect() call will re-create the client
# with the callback anyway.
pass
if self._disconnect_callback: if self._disconnect_callback:
asyncio.create_task(self._disconnect_callback("ble_disconnect")) self._spawn_background(self._disconnect_callback("ble_disconnect"))
def set_disconnect_callback(self, callback): def set_disconnect_callback(self, callback):
"""Set callback to handle disconnections.""" """Set callback to handle disconnections."""
@@ -166,16 +188,24 @@ class BLEConnection:
def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray): def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray):
if self.reader is not None: if self.reader is not None:
asyncio.create_task(self.reader.handle_rx(data)) self._spawn_background(self.reader.handle_rx(data))
async def send(self, data): async def send(self, data):
if not self.client: if not self.client:
logger.error("Client is not connected") logger.error("Client is not connected")
if self._disconnect_callback:
await self._disconnect_callback("ble_transport_lost")
return False return False
if not self.rx_char: if not self.rx_char:
logger.error("RX characteristic not found") logger.error("RX characteristic not found")
return False return False
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True) try:
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
except Exception as exc:
logger.warning(f"BLE write failed: {exc}")
if self._disconnect_callback:
await self._disconnect_callback(f"ble_write_failed: {exc}")
return False
async def disconnect(self): async def disconnect(self):
"""Disconnect from the BLE device.""" """Disconnect from the BLE device."""

View File

@@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes
class CommandHandlerBase: class CommandHandlerBase:
"""Base class for command handlers.
.. note::
The internal ``asyncio.Lock`` is created lazily on first access
so that it binds to the correct running event loop (required for
Python 3.9/3.10 compatibility).
"""
DEFAULT_TIMEOUT = 15.0 DEFAULT_TIMEOUT = 15.0
def __init__(self, default_timeout: Optional[float] = None): def __init__(self, default_timeout: Optional[float] = None):
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
self._reader: Optional[MessageReader] = None self._reader: Optional[MessageReader] = None
self.dispatcher: Optional[EventDispatcher] = None self.dispatcher: Optional[EventDispatcher] = None
self._mesh_request_lock = asyncio.Lock() self.__mesh_request_lock: Optional[asyncio.Lock] = None
self.default_timeout = ( self.default_timeout = (
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
) )
@property
def _mesh_request_lock(self) -> asyncio.Lock:
"""Lazy-init lock so it binds to the running loop, not import-time."""
if self.__mesh_request_lock is None:
self.__mesh_request_lock = asyncio.Lock()
return self.__mesh_request_lock
def set_connection(self, connection: Any) -> None: def set_connection(self, connection: Any) -> None:
async def sender(data: bytes) -> None: async def sender(data: bytes) -> None:
await connection.send(data) await connection.send(data)
@@ -90,6 +105,14 @@ class CommandHandlerBase:
expected_events: Optional[Union[EventType, List[EventType]]] = None, expected_events: Optional[Union[EventType, List[EventType]]] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
) -> Event: ) -> Event:
"""Wait for the first of *expected_events* to arrive.
Returns the first matched ``Event``. When ``EventType.ERROR`` is
among the expected types, the caller **must** check
``result.is_error()`` before accessing command-specific payload
keys — an ERROR payload is ``{"reason": "..."}`` and will
``KeyError`` on any other key.
"""
try: try:
# Convert single event to list if needed # Convert single event to list if needed
if not isinstance(expected_events, list): if not isinstance(expected_events, list):
@@ -129,9 +152,6 @@ class CommandHandlerBase:
logger.debug(f"Command error: {e}") logger.debug(f"Command error: {e}")
return Event(EventType.ERROR, {"error": str(e)}) return Event(EventType.ERROR, {"error": str(e)})
return Event(EventType.ERROR, {})
async def send( async def send(
self, self,
data: bytes, data: bytes,
@@ -151,7 +171,14 @@ class CommandHandlerBase:
timeout: Timeout in seconds, or None to use default_timeout timeout: Timeout in seconds, or None to use default_timeout
Returns: Returns:
Event: The full event object that was received in response to the command Event: The full event object that was received in response to
the command.
Important:
When ``EventType.ERROR`` is included in *expected_events*, the
returned event may be an error response. Callers **must**
check ``result.is_error()`` before accessing command-specific
payload keys to avoid ``KeyError``.
""" """
if not self.dispatcher: if not self.dispatcher:
raise RuntimeError("Dispatcher not set, cannot send commands") raise RuntimeError("Dispatcher not set, cannot send commands")
@@ -170,7 +197,7 @@ class CommandHandlerBase:
futures: List[asyncio.Future] = [] futures: List[asyncio.Future] = []
subscriptions = [] subscriptions = []
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
for event_type in expected_events: for event_type in expected_events:
future = loop.create_future() future = loop.create_future()
@@ -279,6 +306,7 @@ class CommandHandlerBase:
contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path
if contact is None: if contact is None:
logger.error("No contact found") logger.error("No contact found")
return Event(EventType.ERROR, {"reason": "contact_not_found"})
zero_hop = False zero_hop = False
if contact["out_path_len"] == -1: if contact["out_path_len"] == -1:

View File

@@ -191,6 +191,24 @@ class ContactCommands(CommandHandlerBase):
data = b"\x3B" data = b"\x3B"
return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR]) return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR])
async def get_contact_by_key(self, pubkey: bytes) -> Event:
"""N09: Retrieve a single contact by its public key (CMD 30).
Args:
pubkey: 32-byte public key of the contact.
Returns:
Event with the contact data (same format as CONTACT/NEXT_CONTACT),
or ERROR if not found.
"""
if not isinstance(pubkey, (bytes, bytearray)):
raise TypeError("pubkey must be bytes-like")
# Truncate or pad to 32 bytes
key_bytes = bytes(pubkey[:32])
logger.debug(f"Getting contact by key: {key_bytes.hex()}")
data = b"\x1e" + key_bytes
return await self.send(data, [EventType.NEXT_CONTACT, EventType.ERROR])
async def get_advert_path(self, key: DestinationType) -> Event: async def get_advert_path(self, key: DestinationType) -> Event:
key_bytes = _validate_destination(key, prefix_length=32) key_bytes = _validate_destination(key, prefix_length=32)
logger.debug(f"getting advert path for: {key} {key_bytes.hex()}") logger.debug(f"getting advert path for: {key} {key_bytes.hex()}")

View File

@@ -4,6 +4,7 @@ from hashlib import sha256
from typing import Optional from typing import Optional
from ..events import Event, EventType from ..events import Event, EventType
from ..packets import CommandType
from .base import CommandHandlerBase, DestinationType, _validate_destination from .base import CommandHandlerBase, DestinationType, _validate_destination
logger = logging.getLogger("meshcore") logger = logging.getLogger("meshcore")
@@ -13,7 +14,7 @@ class DeviceCommands(CommandHandlerBase):
async def send_appstart(self) -> Event: async def send_appstart(self) -> Event:
logger.debug("Sending appstart command") logger.debug("Sending appstart command")
b1 = bytearray(b"\x01\x03 mccli") b1 = bytearray(b"\x01\x03 mccli")
return await self.send(b1, [EventType.SELF_INFO]) return await self.send(b1, [EventType.SELF_INFO, EventType.ERROR])
async def send_device_query(self) -> Event: async def send_device_query(self) -> Event:
logger.debug("Sending device query command") logger.debug("Sending device query command")
@@ -129,32 +130,50 @@ class DeviceCommands(CommandHandlerBase):
return await self.send(data, [EventType.OK, EventType.ERROR]) return await self.send(data, [EventType.OK, EventType.ERROR])
async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event: async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event:
infos = (await self.send_appstart()).payload result = await self.send_appstart()
if result.is_error():
return result
infos = result.payload
infos["telemetry_mode_base"] = telemetry_mode_base infos["telemetry_mode_base"] = telemetry_mode_base
return await self.set_other_params_from_infos(infos) return await self.set_other_params_from_infos(infos)
async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event: async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event:
infos = (await self.send_appstart()).payload result = await self.send_appstart()
if result.is_error():
return result
infos = result.payload
infos["telemetry_mode_loc"] = telemetry_mode_loc infos["telemetry_mode_loc"] = telemetry_mode_loc
return await self.set_other_params_from_infos(infos) return await self.set_other_params_from_infos(infos)
async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event: async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event:
infos = (await self.send_appstart()).payload result = await self.send_appstart()
if result.is_error():
return result
infos = result.payload
infos["telemetry_mode_env"] = telemetry_mode_env infos["telemetry_mode_env"] = telemetry_mode_env
return await self.set_other_params_from_infos(infos) return await self.set_other_params_from_infos(infos)
async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event: async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event:
infos = (await self.send_appstart()).payload result = await self.send_appstart()
if result.is_error():
return result
infos = result.payload
infos["manual_add_contacts"] = manual_add_contacts infos["manual_add_contacts"] = manual_add_contacts
return await self.set_other_params_from_infos(infos) return await self.set_other_params_from_infos(infos)
async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event: async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event:
infos = (await self.send_appstart()).payload result = await self.send_appstart()
if result.is_error():
return result
infos = result.payload
infos["adv_loc_policy"] = advert_loc_policy infos["adv_loc_policy"] = advert_loc_policy
return await self.set_other_params_from_infos(infos) return await self.set_other_params_from_infos(infos)
async def set_multi_acks(self, multi_acks: int) -> Event: async def set_multi_acks(self, multi_acks: int) -> Event:
infos = (await self.send_appstart()).payload result = await self.send_appstart()
if result.is_error():
return result
infos = result.payload
infos["multi_acks"] = multi_acks infos["multi_acks"] = multi_acks
return await self.set_other_params_from_infos(infos) return await self.set_other_params_from_infos(infos)
@@ -273,20 +292,89 @@ class DeviceCommands(CommandHandlerBase):
return await self.sign_finish(timeout=timeout, data_size=len(data)) return await self.sign_finish(timeout=timeout, data_size=len(data))
async def has_connection(self) -> Event:
"""N09: Check if the device has an active connection (CMD 28).
Returns:
Event with a 1-byte response indicating connection status,
or ERROR.
"""
logger.debug("Checking device connection status")
return await self.send(b"\x1c", [EventType.OK, EventType.ERROR])
async def get_tuning(self) -> Event:
"""N03/N09: Request current tuning parameters (CMD_GET_TUNING_PARAMS = 43).
Firmware responds with RESP_CODE_TUNING_PARAMS (23): 9 bytes containing
rx_delay (4 bytes LE) and airtime_factor (4 bytes LE).
Returns:
Event of type TUNING_PARAMS with rx_delay and airtime_factor,
or ERROR.
"""
logger.debug("Getting tuning parameters")
return await self.send(b"\x2b", [EventType.TUNING_PARAMS, EventType.ERROR])
async def request_factory_reset(self) -> str:
"""N09: Request a factory reset token (step 1 of 2).
This method returns a confirmation token string. Pass it to
``confirm_factory_reset(token)`` to actually execute the reset.
The two-step pattern is a Python-side safety measure; the firmware
itself has no token verification.
Returns:
A confirmation token string to pass to confirm_factory_reset().
"""
import secrets
token = secrets.token_hex(8)
logger.warning(
"Factory reset requested. Call confirm_factory_reset('%s') to proceed. "
"This will ERASE ALL DATA on the device.", token
)
# Store the token on the instance for validation
self._factory_reset_token = token
return token
async def confirm_factory_reset(self, token: str) -> Event:
"""N09: Execute factory reset after token confirmation (step 2 of 2).
Args:
token: The token returned by request_factory_reset().
Returns:
Event with OK or ERROR.
Raises:
ValueError: If the token does not match.
"""
expected = getattr(self, "_factory_reset_token", None)
if expected is None or token != expected:
raise ValueError(
"Invalid or expired factory reset token. "
"Call request_factory_reset() first."
)
self._factory_reset_token = None # Consume the token
logger.warning("Executing factory reset — all device data will be erased")
return await self.send(b"\x33", [EventType.OK, EventType.ERROR])
async def get_stats_core(self) -> Event: async def get_stats_core(self) -> Event:
logger.debug("Getting core statistics") logger.debug("Getting core statistics")
# CMD_GET_STATS (56) + STATS_TYPE_CORE (0) # R04: Use CommandType enum instead of literal bytes
return await self.send(b"\x38\x00", [EventType.STATS_CORE, EventType.ERROR]) cmd = bytes([CommandType.GET_STATS.value, 0x00]) # GET_STATS + STATS_TYPE_CORE
return await self.send(cmd, [EventType.STATS_CORE, EventType.ERROR])
async def get_stats_radio(self) -> Event: async def get_stats_radio(self) -> Event:
logger.debug("Getting radio statistics") logger.debug("Getting radio statistics")
# CMD_GET_STATS (56) + STATS_TYPE_RADIO (1) # R04: Use CommandType enum instead of literal bytes
return await self.send(b"\x38\x01", [EventType.STATS_RADIO, EventType.ERROR]) cmd = bytes([CommandType.GET_STATS.value, 0x01]) # GET_STATS + STATS_TYPE_RADIO
return await self.send(cmd, [EventType.STATS_RADIO, EventType.ERROR])
async def get_stats_packets(self) -> Event: async def get_stats_packets(self) -> Event:
logger.debug("Getting packet statistics") logger.debug("Getting packet statistics")
# CMD_GET_STATS (56) + STATS_TYPE_PACKETS (2) # R04: Use CommandType enum instead of literal bytes
return await self.send(b"\x38\x02", [EventType.STATS_PACKETS, EventType.ERROR]) cmd = bytes([CommandType.GET_STATS.value, 0x02]) # GET_STATS + STATS_TYPE_PACKETS
return await self.send(cmd, [EventType.STATS_PACKETS, EventType.ERROR])
async def get_allowed_repeat_freq(self) -> Event: async def get_allowed_repeat_freq(self) -> Event:
logger.debug("Getting allowed repeat freqs") logger.debug("Getting allowed repeat freqs")

View File

@@ -144,8 +144,12 @@ class MessagingCommands(CommandHandlerBase):
logger.info(f"Retry sending msg: {attempts + 1}") logger.info(f"Retry sending msg: {attempts + 1}")
result = await self.send_msg(dst, msg, timestamp, attempt=attempts) result = await self.send_msg(dst, msg, timestamp, attempt=attempts)
if result.type == EventType.ERROR: if result.is_error():
logger.error(f"⚠️ Failed to send message: {result.payload}") logger.error(f"Failed to send message: {result.payload}")
attempts += 1
if flood:
flood_attempts += 1
continue
exp_ack = result.payload["expected_ack"].hex() exp_ack = result.payload["expected_ack"].hex()
timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout
@@ -255,7 +259,7 @@ class MessagingCommands(CommandHandlerBase):
elif path_hash_len == 8 : elif path_hash_len == 8 :
flags = 3 flags = 3
else : else :
logger.error(f"Invalid path format: {e}") logger.error(f"Invalid path format: unknown path_hash_len {path_hash_len}")
return Event(EventType.ERROR, {"reason": "invalid_path_format"}) return Event(EventType.ERROR, {"reason": "invalid_path_format"})
else: else:
flags = 0 flags = 0
@@ -291,12 +295,34 @@ class MessagingCommands(CommandHandlerBase):
cmd_data.append(flags) cmd_data.append(flags)
cmd_data.extend(path_bytes) cmd_data.extend(path_bytes)
# N05: Firmware requires strict len > 10 (MyMesh.cpp:1620).
# When path is empty, cmd(1)+tag(4)+auth(4)+flags(1) = 10 bytes exactly,
# which is silently rejected. Pad with one zero byte to reach 11.
if len(cmd_data) <= 10:
cmd_data.append(0x00)
logger.debug( logger.debug(
f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}" f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}"
) )
return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR]) return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR])
async def send_raw_data(self, payload: bytes) -> Event:
"""N09: Send raw data via CMD_SEND_RAW_DATA (25).
Sends an arbitrary payload through the mesh network.
Args:
payload: Raw bytes to send.
Returns:
Event with MSG_SENT or ERROR.
"""
if not isinstance(payload, (bytes, bytearray)):
raise TypeError("payload must be bytes-like")
data = b"\x19" + bytes(payload)
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def set_flood_scope(self, scope): async def set_flood_scope(self, scope):
if scope is None: if scope is None:
logger.debug(f"Resetting scope") logger.debug(f"Resetting scope")

View File

@@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type.
import asyncio import asyncio
import logging import logging
from typing import Optional, Any, Callable, Protocol from typing import Optional, Any, Awaitable, Callable, Protocol
from .events import Event, EventType from .events import Event, EventType
logger = logging.getLogger("meshcore") logger = logging.getLogger("meshcore")
class ConnectionProtocol(Protocol): class ConnectionProtocol(Protocol):
"""Protocol defining the interface that connection classes must implement.""" """Protocol defining the interface that connection classes must implement.
Return contract for connect():
- On success: return a truthy value (typically an address string)
that identifies the connection. This value is included in the
CONNECTED event payload as ``connection_info``.
- On failure: return ``None`` (soft failure — triggers a retry in
``_attempt_reconnect``) **or** raise an exception (hard failure —
also triggers a retry, logged as an error).
"""
async def connect(self) -> Optional[Any]: async def connect(self) -> Optional[Any]:
"""Connect and return connection info, or None if failed.""" """Connect and return connection info, or None if failed."""
@@ -39,11 +48,13 @@ class ConnectionManager:
event_dispatcher=None, event_dispatcher=None,
auto_reconnect: bool = False, auto_reconnect: bool = False,
max_reconnect_attempts: int = 3, max_reconnect_attempts: int = 3,
reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None,
): ):
self.connection = connection self.connection = connection
self.event_dispatcher = event_dispatcher self.event_dispatcher = event_dispatcher
self.auto_reconnect = auto_reconnect self.auto_reconnect = auto_reconnect
self.max_reconnect_attempts = max_reconnect_attempts self.max_reconnect_attempts = max_reconnect_attempts
self._reconnect_callback = reconnect_callback
self._reconnect_attempts = 0 self._reconnect_attempts = 0
self._is_connected = False self._is_connected = False
@@ -109,45 +120,51 @@ class ConnectionManager:
) )
async def _attempt_reconnect(self): async def _attempt_reconnect(self):
"""Attempt to reconnect with flat delay.""" """Attempt to reconnect using an iterative loop.
logger.debug(
f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})"
)
self._reconnect_attempts += 1
# Flat 1 second delay for all attempts Runs as a single persistent task for the entire reconnect session.
await asyncio.sleep(1) Previous implementation used tail-recursion via create_task which
orphaned the running task reference — disconnect() could only cancel
the newest pointer, leaving earlier attempts in flight (F03).
"""
while self._reconnect_attempts < self.max_reconnect_attempts:
self._reconnect_attempts += 1
logger.debug(
f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})"
)
# Flat 1 second delay for all attempts
await asyncio.sleep(1)
try:
result = await self.connection.connect()
if result is not None:
self._is_connected = True
self._reconnect_attempts = 0
# Invoke reconnect callback (e.g. send_appstart) if provided
if self._reconnect_callback is not None:
try:
await self._reconnect_callback()
except Exception as cb_err:
logger.warning(
f"Reconnect callback failed: {cb_err}"
)
try:
result = await self.connection.connect()
if result is not None:
self._is_connected = True
self._reconnect_attempts = 0
await self._emit_event(
EventType.CONNECTED,
{"connection_info": result, "reconnected": True},
)
logger.debug("Reconnected successfully")
else:
# Reconnection failed, try again if we haven't exceeded max attempts
if self._reconnect_attempts < self.max_reconnect_attempts:
self._reconnect_task = asyncio.create_task(
self._attempt_reconnect()
)
else:
await self._emit_event( await self._emit_event(
EventType.DISCONNECTED, EventType.CONNECTED,
{"reason": "reconnect_failed", "max_attempts_exceeded": True}, {"connection_info": result, "reconnected": True},
) )
except Exception as e: logger.debug("Reconnected successfully")
logger.debug(f"Reconnection attempt failed: {e}") return
if self._reconnect_attempts < self.max_reconnect_attempts: except Exception as e:
self._reconnect_task = asyncio.create_task(self._attempt_reconnect()) logger.debug(f"Reconnection attempt failed: {e}")
else:
await self._emit_event( # All attempts exhausted
EventType.DISCONNECTED, await self._emit_event(
{"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True}, EventType.DISCONNECTED,
) {"reason": "reconnect_failed", "max_attempts_exceeded": True},
)
async def _emit_event(self, event_type: EventType, payload: dict): async def _emit_event(self, event_type: EventType, payload: dict):
"""Emit connection events if dispatcher is available.""" """Emit connection events if dispatcher is available."""

View File

@@ -49,6 +49,9 @@ class EventType(Enum):
PATH_RESPONSE = "path_response" PATH_RESPONSE = "path_response"
PRIVATE_KEY = "private_key" PRIVATE_KEY = "private_key"
DISABLED = "disabled" DISABLED = "disabled"
CONTACT_DELETED = "contact_deleted"
CONTACTS_FULL = "contacts_full"
TUNING_PARAMS = "tuning_params"
CONTROL_DATA = "control_data" CONTROL_DATA = "control_data"
DISCOVER_RESPONSE = "discover_response" DISCOVER_RESPONSE = "discover_response"
NEIGHBOURS_RESPONSE = "neighbours_response" NEIGHBOURS_RESPONSE = "neighbours_response"
@@ -104,6 +107,17 @@ class Event:
if kwargs: if kwargs:
self.attributes.update(kwargs) self.attributes.update(kwargs)
def is_error(self) -> bool:
"""Return True if this event represents an error response.
Callers that include ``EventType.ERROR`` in their expected-events
list **must** check ``result.is_error()`` (or ``result.type ==
EventType.ERROR``) before accessing keyed payload fields, because
an ERROR payload contains ``{"reason": "..."}`` — not the
command-specific keys the caller expects on the happy path.
"""
return self.type == EventType.ERROR
def clone(self): def clone(self):
""" """
Create a copy of the event. Create a copy of the event.
@@ -129,11 +143,28 @@ class Subscription:
class EventDispatcher: class EventDispatcher:
"""Event dispatch engine.
.. note::
``start()`` must be called before dispatching or processing events.
The internal ``asyncio.Queue`` is created lazily inside ``start()``
so that it binds to the correct running event loop (required for
Python 3.9/3.10 compatibility).
"""
def __init__(self): def __init__(self):
self.queue: asyncio.Queue[Event] = asyncio.Queue() self.queue: Optional[asyncio.Queue[Event]] = None
self.subscriptions: List[Subscription] = [] self.subscriptions: List[Subscription] = []
self.running = False self.running = False
self._task = None self._task = None
self._background_tasks: set[asyncio.Task] = set()
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
def subscribe( def subscribe(
self, self,
@@ -166,6 +197,10 @@ class EventDispatcher:
self.subscriptions.remove(subscription) self.subscriptions.remove(subscription)
async def dispatch(self, event: Event): async def dispatch(self, event: Event):
if self.queue is None:
raise RuntimeError(
"EventDispatcher.start() must be called before dispatching events"
)
await self.queue.put(event) await self.queue.put(event)
async def _process_events(self): async def _process_events(self):
@@ -197,7 +232,7 @@ class EventDispatcher:
# returns - avoids the race where create_task schedules the callback after # returns - avoids the race where create_task schedules the callback after
# the waiter has already timed out with done=set(). # the waiter has already timed out with done=set().
if asyncio.iscoroutinefunction(subscription.callback): if asyncio.iscoroutinefunction(subscription.callback):
asyncio.create_task(self._execute_callback(subscription.callback, event.clone())) self._spawn_background(self._execute_callback(subscription.callback, event.clone()))
else: else:
try: try:
subscription.callback(event.clone()) subscription.callback(event.clone())
@@ -220,6 +255,8 @@ class EventDispatcher:
async def start(self): async def start(self):
if not self.running: if not self.running:
if self.queue is None:
self.queue = asyncio.Queue()
self.running = True self.running = True
self._task = asyncio.create_task(self._process_events()) self._task = asyncio.create_task(self._process_events())
@@ -227,7 +264,12 @@ class EventDispatcher:
if self.running: if self.running:
self.running = False self.running = False
if self._task: if self._task:
await self.queue.join() if self.queue is not None:
await self.queue.join()
# Wait for any in-flight async callbacks to complete before
# tearing down (F07: task_done fires before callbacks finish).
if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._task.cancel() self._task.cancel()
try: try:
await self._task await self._task

View File

@@ -28,10 +28,17 @@ class MeshCore:
auto_reconnect: bool = False, auto_reconnect: bool = False,
max_reconnect_attempts: int = 3, max_reconnect_attempts: int = 3,
): ):
# Wrap connection with ConnectionManager # Wrap connection with ConnectionManager.
# The reconnect callback ensures send_appstart() runs after every
# transport-level reconnect, which is required by firmware to
# initialize the session (F02).
self.dispatcher = EventDispatcher() self.dispatcher = EventDispatcher()
self.connection_manager = ConnectionManager( self.connection_manager = ConnectionManager(
cx, self.dispatcher, auto_reconnect, max_reconnect_attempts cx,
self.dispatcher,
auto_reconnect,
max_reconnect_attempts,
reconnect_callback=self._on_reconnect,
) )
self.cx = self.connection_manager # For backward compatibility self.cx = self.connection_manager # For backward compatibility
@@ -174,6 +181,15 @@ class MeshCore:
return None return None
return mc return mc
async def _on_reconnect(self):
"""Callback invoked by ConnectionManager after a successful reconnect.
Firmware requires CMD_APP_START after every transport-level connection
to initialize the session. MeshCore.connect() does this on the initial
connection; this callback ensures it also happens on reconnects (F02).
"""
await self.commands.send_appstart()
async def connect(self): async def connect(self):
await self.dispatcher.start() await self.dispatcher.start()
result = await self.connection_manager.connect() result = await self.connection_manager.connect()

View File

@@ -42,6 +42,28 @@ class MeshcorePacketParser:
Returns : Returns :
completed log_data completed log_data
""" """
# Minimum viable payload is 2 bytes (1 header + 1 path_byte) for a
# direct route. Anything shorter is provably broken — for example,
# the LOG_DATA branch in reader.py only requires `len(data) > 3`,
# which means a 4-byte LOG_DATA frame produces a 1-byte payload
# here, and `path_byte = pbuf.read(1)[0]` further down would raise
# IndexError on the empty buffer. Populate sentinel values so the
# caller's downstream `log_data['route_type']` etc. lookups don't
# KeyError, then return early.
if len(payload) < 2:
logger.debug(f"parsePacketPayload: payload too short ({len(payload)} bytes < 2), returning sentinel log_data")
log_data["route_type"] = -1
log_data["route_typename"] = "UNK"
log_data["payload_type"] = -1
log_data["payload_typename"] = "UNK"
log_data["payload_ver"] = 0
log_data["path_len"] = 0
log_data["path_hash_size"] = 1
log_data["path"] = ""
log_data["pkt_payload"] = b""
log_data["pkt_hash"] = 0
return log_data
pbuf = io.BytesIO(payload) pbuf = io.BytesIO(payload)
header = pbuf.read(1)[0] header = pbuf.read(1)[0]
@@ -128,7 +150,7 @@ class MeshcorePacketParser:
uncrypted = cipher.decrypt(msg) uncrypted = cipher.decrypt(msg)
timestamp = int.from_bytes(uncrypted[0:4], "little", signed=False) timestamp = int.from_bytes(uncrypted[0:4], "little", signed=False)
attempt = uncrypted[4] & 3 attempt = uncrypted[4] & 3
txt_type = int.from_bytes(uncrypted[4:4], "little", signed=False) >> 2 txt_type = int.from_bytes(uncrypted[4:5], "little", signed=False) >> 2
message = uncrypted[5:].strip(b"\0") message = uncrypted[5:].strip(b"\0")
msg_hash = int.from_bytes(SHA256.new(timestamp.to_bytes(4, "little", signed=False) + message).digest()[0:4], "little", signed=False) msg_hash = int.from_bytes(SHA256.new(timestamp.to_bytes(4, "little", signed=False) + message).digest()[0:4], "little", signed=False)
log_data["message"] = message.decode("utf-8", "ignore") log_data["message"] = message.decode("utf-8", "ignore")
@@ -149,39 +171,42 @@ class MeshcorePacketParser:
del self.channels_log[:25] del self.channels_log[:25]
elif not payload is None and payload_type == 0x04: # Advert elif not payload is None and payload_type == 0x04: # Advert
pk_buf = io.BytesIO(pkt_payload) try:
adv_key = pk_buf.read(32).hex() pk_buf = io.BytesIO(pkt_payload)
adv_timestamp = int.from_bytes(pk_buf.read(4), "little", signed=False) adv_key = pk_buf.read(32).hex()
signature = pk_buf.read(64).hex() adv_timestamp = int.from_bytes(pk_buf.read(4), "little", signed=False)
flags = pk_buf.read(1)[0] signature = pk_buf.read(64).hex()
adv_type = flags & 0x0F flags = pk_buf.read(1)[0]
adv_lat = None adv_type = flags & 0x0F
adv_lon = None adv_lat = None
adv_feat1 = None adv_lon = None
adv_feat2 = None adv_feat1 = None
if flags & 0x10 > 0: #has location adv_feat2 = None
adv_lat = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0 if flags & 0x10 > 0: #has location
adv_lon = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0 adv_lat = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0
if flags & 0x20 > 0: #has feature1 adv_lon = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0
adv_feat1 = pk_buf.read(2).hex() if flags & 0x20 > 0: #has feature1
if flags & 0x40 > 0: #has feature2 adv_feat1 = pk_buf.read(2).hex()
adv_feat2 = pk_buf.read(2).hex() if flags & 0x40 > 0: #has feature2
if flags & 0x80 > 0: #has name adv_feat2 = pk_buf.read(2).hex()
adv_name = pk_buf.read().decode("utf-8", "ignore").strip("\x00") if flags & 0x80 > 0: #has name
log_data["adv_name"] = adv_name adv_name = pk_buf.read().decode("utf-8", "ignore").strip("\x00")
log_data["adv_name"] = adv_name
log_data["adv_key"] = adv_key log_data["adv_key"] = adv_key
log_data["adv_timestamp"] = adv_timestamp log_data["adv_timestamp"] = adv_timestamp
log_data["signature"] = signature log_data["signature"] = signature
log_data["adv_flags"] = flags log_data["adv_flags"] = flags
log_data["adv_type"] = adv_type log_data["adv_type"] = adv_type
if not adv_lat is None : if not adv_lat is None :
log_data["adv_lat"] = adv_lat log_data["adv_lat"] = adv_lat
if not adv_lon is None : if not adv_lon is None :
log_data["adv_lon"] = adv_lon log_data["adv_lon"] = adv_lon
if not adv_feat1 is None: if not adv_feat1 is None:
log_data["adv_feat1"] = adv_feat1 log_data["adv_feat1"] = adv_feat1
if not adv_feat2 is None: if not adv_feat2 is None:
log_data["adv_feat2"] = adv_feat2 log_data["adv_feat2"] = adv_feat2
except (IndexError, ValueError) as e:
logger.debug(f"parsePacketPayload: malformed ADVERT payload ({type(e).__name__}: {e}), len={len(pkt_payload)}")
return log_data return log_data

View File

@@ -71,6 +71,7 @@ class CommandType(Enum):
SET_AUTOADD_CONFIG = 58 SET_AUTOADD_CONFIG = 58
GET_AUTOADD_CONFIG = 59 GET_AUTOADD_CONFIG = 59
GET_ALLOWED_REPEAT_FREQ = 60 GET_ALLOWED_REPEAT_FREQ = 60
GET_STATS = 56 # R04: CMD_GET_STATS — used by get_stats_core/radio/packets
SET_PATH_HASH_MODE = 61 SET_PATH_HASH_MODE = 61
# Packet prefixes for the protocol # Packet prefixes for the protocol
@@ -120,3 +121,6 @@ class PacketType(Enum):
PATH_DISCOVERY_RESPONSE = 0x8D PATH_DISCOVERY_RESPONSE = 0x8D
CONTROL_DATA = 0x8E CONTROL_DATA = 0x8E
CONTACT_DELETED = 0x8F CONTACT_DELETED = 0x8F
CONTACTS_FULL = 0x90 # N02: MyMesh::onContactsFull() — 1-byte push, no payload
# Note: 0x90 == ControlType.NODE_DISCOVER_RESP in a different namespace.
# Not a literal conflict (PacketType vs ControlType), but a maintenance hazard.

File diff suppressed because it is too large Load Diff

View File

@@ -20,11 +20,19 @@ class SerialConnection:
self._disconnect_callback = None self._disconnect_callback = None
self.cx_dly = cx_dly self.cx_dly = cx_dly
self._connected_event = asyncio.Event() self._connected_event = asyncio.Event()
self._background_tasks: set[asyncio.Task] = set()
self.frame_expected_size = 0 self.frame_expected_size = 0
self.inframe = b"" self.inframe = b""
self.header = b"" self.header = b""
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
class MCSerialClientProtocol(asyncio.Protocol): class MCSerialClientProtocol(asyncio.Protocol):
def __init__(self, cx): def __init__(self, cx):
self.cx = cx self.cx = cx
@@ -44,7 +52,7 @@ class SerialConnection:
self.cx._connected_event.clear() self.cx._connected_event.clear()
if self.cx._disconnect_callback: if self.cx._disconnect_callback:
asyncio.create_task(self.cx._disconnect_callback("serial_disconnect")) self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect"))
def pause_writing(self): def pause_writing(self):
logger.debug("pause writing") logger.debug("pause writing")
@@ -52,9 +60,13 @@ class SerialConnection:
def resume_writing(self): def resume_writing(self):
logger.debug("resume writing") logger.debug("resume writing")
async def connect(self): async def connect(self, timeout: float = 10.0):
""" """
Connects to the device Connects to the device.
Args:
timeout: Maximum seconds to wait for connection_made callback.
Defaults to 10.0. Raises asyncio.TimeoutError on expiry.
""" """
self._connected_event.clear() self._connected_event.clear()
@@ -66,7 +78,7 @@ class SerialConnection:
baudrate=self.baudrate, baudrate=self.baudrate,
) )
await self._connected_event.wait() await asyncio.wait_for(self._connected_event.wait(), timeout=timeout)
logger.info("Serial Connection started") logger.info("Serial Connection started")
return self.port return self.port
@@ -102,7 +114,7 @@ class SerialConnection:
self.frame_expected_size = 0 self.frame_expected_size = 0
if len(data) > 0: # rerun handle_rx on remaining data if len(data) > 0: # rerun handle_rx on remaining data
self.handle_rx(data) self.handle_rx(data)
return return # nothing left to process after reset
upbound = self.frame_expected_size - len(self.inframe) upbound = self.frame_expected_size - len(self.inframe)
if len(data) < upbound: if len(data) < upbound:
@@ -114,7 +126,7 @@ class SerialConnection:
data = data[upbound:] data = data[upbound:]
if self.reader is not None: if self.reader is not None:
# feed meshcore reader # feed meshcore reader
asyncio.create_task(self.reader.handle_rx(self.inframe)) self._spawn_background(self.reader.handle_rx(self.inframe))
# reset inframe # reset inframe
self.inframe = b"" self.inframe = b""
self.header = b"" self.header = b""
@@ -125,11 +137,18 @@ class SerialConnection:
async def send(self, data): async def send(self, data):
if not self.transport: if not self.transport:
logger.error("Transport not connected, cannot send data") logger.error("Transport not connected, cannot send data")
if self._disconnect_callback:
await self._disconnect_callback("serial_transport_lost")
return return
size = len(data) size = len(data)
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
logger.debug(f"sending pkt : {pkt}") logger.debug(f"sending pkt : {pkt}")
self.transport.write(pkt) try:
self.transport.write(pkt)
except OSError as exc:
logger.warning(f"Serial write failed: {exc}")
if self._disconnect_callback:
await self._disconnect_callback(f"serial_write_failed: {exc}")
async def disconnect(self): async def disconnect(self):
"""Close the serial connection.""" """Close the serial connection."""

View File

@@ -24,6 +24,14 @@ class TCPConnection:
self.frame_expected_size = 0 self.frame_expected_size = 0
self.header = b"" self.header = b""
self.inframe = b"" self.inframe = b""
self._background_tasks: set[asyncio.Task] = set()
def _spawn_background(self, coro) -> asyncio.Task:
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
task = asyncio.create_task(coro)
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
return task
class MCClientProtocol(asyncio.Protocol): class MCClientProtocol(asyncio.Protocol):
def __init__(self, cx): def __init__(self, cx):
@@ -38,7 +46,6 @@ class TCPConnection:
def data_received(self, data): def data_received(self, data):
logger.debug("data received") logger.debug("data received")
self.cx._receive_count += 1
self.cx.handle_rx(data) self.cx.handle_rx(data)
def error_received(self, exc): def error_received(self, exc):
@@ -47,7 +54,7 @@ class TCPConnection:
def connection_lost(self, exc): def connection_lost(self, exc):
logger.debug("TCP server closed the connection") logger.debug("TCP server closed the connection")
if self.cx._disconnect_callback: if self.cx._disconnect_callback:
asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect")) self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect"))
async def connect(self): async def connect(self):
""" """
@@ -59,10 +66,7 @@ class TCPConnection:
) )
logger.info("TCP Connection started") logger.info("TCP Connection started")
future = asyncio.Future() return self.host
future.set_result(self.host)
return future
def set_reader(self, reader): def set_reader(self, reader):
self.reader = reader self.reader = reader
@@ -96,7 +100,7 @@ class TCPConnection:
self.frame_expected_size = 0 self.frame_expected_size = 0
if len(data) > 0: # rerun handle_rx on remaining data if len(data) > 0: # rerun handle_rx on remaining data
self.handle_rx(data) self.handle_rx(data)
return return # nothing left to process after reset
upbound = self.frame_expected_size - len(self.inframe) upbound = self.frame_expected_size - len(self.inframe)
if len(data) < upbound : if len(data) < upbound :
@@ -106,9 +110,13 @@ class TCPConnection:
self.inframe = self.inframe + data[0:upbound] self.inframe = self.inframe + data[0:upbound]
data = data[upbound:] data = data[upbound:]
# Increment per completed MeshCore frame, not per TCP segment (N04).
# The threshold heuristic in send() compares _send_count vs
# _receive_count — counting per-segment skews it under fragmentation.
self._receive_count += 1
if self.reader is not None: if self.reader is not None:
# feed meshcore reader # feed meshcore reader
asyncio.create_task(self.reader.handle_rx(self.inframe)) self._spawn_background(self.reader.handle_rx(self.inframe))
# reset inframe # reset inframe
self.inframe = b"" self.inframe = b""
self.header = b"" self.header = b""
@@ -137,7 +145,12 @@ class TCPConnection:
size = len(data) size = len(data)
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
logger.debug(f"sending pkt : {pkt}") logger.debug(f"sending pkt : {pkt}")
self.transport.write(pkt) try:
self.transport.write(pkt)
except (OSError, ConnectionResetError) as exc:
logger.warning(f"TCP write failed: {exc}")
if self._disconnect_callback:
await self._disconnect_callback(f"tcp_write_failed: {exc}")
async def disconnect(self): async def disconnect(self):
"""Close the TCP connection.""" """Close the TCP connection."""

View File

@@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase):
@patch("meshcore.ble_cx.BleakClient") @patch("meshcore.ble_cx.BleakClient")
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client): def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
"""Test BLE connection with PIN when pairing fails but connection continues""" """Test BLE connection with PIN when pairing fails — re-raises (F17)."""
# Arrange # Arrange
mock_client_instance = self._get_mock_bleak_client() mock_client_instance = self._get_mock_bleak_client()
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed")) mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
@@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase):
pin = "123456" pin = "123456"
ble_conn = BLEConnection(address=address, pin=pin) ble_conn = BLEConnection(address=address, pin=pin)
# Act # Act & Assert — pairing failure now re-raises instead of being
result = asyncio.run(ble_conn.connect()) # swallowed, because a half-usable transport is worse than a clean
# failure (forensics finding F17).
# Assert with self.assertRaises(Exception) as ctx:
asyncio.run(ble_conn.connect())
self.assertIn("Pairing failed", str(ctx.exception))
mock_client_instance.connect.assert_called_once() mock_client_instance.connect.assert_called_once()
mock_client_instance.pair.assert_called_once() mock_client_instance.pair.assert_called_once()
mock_client_instance.start_notify.assert_called_once_with( # disconnect should be called to clean up the failed connection
UART_TX_CHAR_UUID, ble_conn.handle_rx mock_client_instance.disconnect.assert_called_once()
)
# Connection should still succeed even if pairing fails
self.assertEqual(result, address)
@patch("meshcore.ble_cx.BleakClient") @patch("meshcore.ble_cx.BleakClient")
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client): def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):

View File

@@ -0,0 +1,235 @@
"""
Verification tests for asyncio lifecycle fixes.
"""
import asyncio
import gc
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
from meshcore.events import Event, EventDispatcher, EventType
from meshcore.tcp_cx import TCPConnection
from meshcore.serial_cx import SerialConnection
from meshcore.commands.base import CommandHandlerBase
class TestBackgroundTaskTracking(unittest.TestCase):
"""Fire-and-forget create_task calls must be tracked to prevent GC."""
def test_tcp_spawn_background_retains_task(self):
"""TCP _spawn_background adds the task to _background_tasks."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
completed = asyncio.Event()
async def dummy():
completed.set()
task = cx._spawn_background(dummy())
assert task in cx._background_tasks
await completed.wait()
# After completion, done_callback should have discarded it
await asyncio.sleep(0) # let done callback fire
assert task not in cx._background_tasks
asyncio.run(_run())
def test_serial_spawn_background_retains_task(self):
"""Serial _spawn_background adds the task to _background_tasks."""
async def _run():
with patch("meshcore.serial_cx.asyncio.Event") as mock_event:
mock_event.return_value = MagicMock()
cx = SerialConnection("/dev/null", 115200)
completed = asyncio.Event()
async def dummy():
completed.set()
task = cx._spawn_background(dummy())
assert task in cx._background_tasks
await completed.wait()
await asyncio.sleep(0)
assert task not in cx._background_tasks
asyncio.run(_run())
def test_event_dispatcher_spawn_background_retains_task(self):
"""EventDispatcher _spawn_background adds task to _background_tasks."""
async def _run():
dispatcher = EventDispatcher()
completed = asyncio.Event()
async def dummy():
completed.set()
task = dispatcher._spawn_background(dummy())
assert task in dispatcher._background_tasks
await completed.wait()
await asyncio.sleep(0)
assert task not in dispatcher._background_tasks
asyncio.run(_run())
def test_tcp_handle_rx_uses_tracked_task(self):
"""TCP handle_rx dispatches reader.handle_rx via _spawn_background."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
reader = AsyncMock()
reader.handle_rx = AsyncMock()
cx.set_reader(reader)
# Build a minimal valid frame: 0x3e + 2-byte LE size + payload
payload = b"\x01\x02\x03"
size = len(payload).to_bytes(2, "little")
frame = b"\x3e" + size + payload
cx.handle_rx(frame)
# Task should be tracked
assert len(cx._background_tasks) == 1
# Let task complete
await asyncio.sleep(0.05)
reader.handle_rx.assert_awaited_once_with(payload)
asyncio.run(_run())
def test_tcp_connection_lost_uses_tracked_task(self):
"""TCP connection_lost dispatches disconnect callback via _spawn_background."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
callback = AsyncMock()
cx.set_disconnect_callback(callback)
protocol = cx.MCClientProtocol(cx)
protocol.connection_lost(None)
assert len(cx._background_tasks) == 1
await asyncio.sleep(0.05)
callback.assert_awaited_once_with("tcp_disconnect")
asyncio.run(_run())
def test_gc_does_not_cancel_tracked_tasks(self):
"""Tracked tasks survive GC pressure (the whole point of tracking)."""
async def _run():
cx = TCPConnection("127.0.0.1", 5555)
result = []
async def slow_task():
await asyncio.sleep(0.05)
result.append("done")
cx._spawn_background(slow_task())
# Force GC — untracked tasks could be collected here
gc.collect()
await asyncio.sleep(0.1)
assert result == ["done"]
asyncio.run(_run())
class TestTaskDoneCorrectness(unittest.TestCase):
"""EventDispatcher.stop() must wait for in-flight async callbacks."""
def test_stop_waits_for_async_callbacks(self):
"""stop() should not return until async callbacks have completed."""
async def _run():
dispatcher = EventDispatcher()
await dispatcher.start()
callback_completed = False
async def slow_callback(event):
nonlocal callback_completed
await asyncio.sleep(0.1)
callback_completed = True
dispatcher.subscribe(EventType.OK, slow_callback)
await dispatcher.dispatch(Event(EventType.OK, {}))
# Give the dispatch loop a moment to pick up the event
await asyncio.sleep(0.02)
# stop() should wait for slow_callback to finish
await dispatcher.stop()
assert callback_completed, "stop() returned before async callback completed"
asyncio.run(_run())
class TestDeferredPrimitiveConstruction(unittest.TestCase):
"""Queue and Lock must not bind to import-time loop."""
def test_event_dispatcher_queue_is_none_before_start(self):
"""EventDispatcher.queue should be None until start() is called."""
dispatcher = EventDispatcher()
assert dispatcher.queue is None
def test_event_dispatcher_queue_created_on_start(self):
"""start() creates the queue."""
async def _run():
dispatcher = EventDispatcher()
assert dispatcher.queue is None
await dispatcher.start()
assert dispatcher.queue is not None
assert isinstance(dispatcher.queue, asyncio.Queue)
await dispatcher.stop()
asyncio.run(_run())
def test_event_dispatcher_dispatch_before_start_raises(self):
"""dispatch() before start() should raise RuntimeError."""
async def _run():
dispatcher = EventDispatcher()
with self.assertRaises(RuntimeError):
await dispatcher.dispatch(Event(EventType.OK, {}))
asyncio.run(_run())
def test_command_handler_lock_is_none_before_use(self):
"""CommandHandlerBase lock should be None until first access."""
handler = CommandHandlerBase()
assert handler._CommandHandlerBase__mesh_request_lock is None
def test_command_handler_lock_created_on_access(self):
"""Accessing _mesh_request_lock creates it lazily."""
async def _run():
handler = CommandHandlerBase()
lock = handler._mesh_request_lock
assert isinstance(lock, asyncio.Lock)
# Second access returns same instance
assert handler._mesh_request_lock is lock
asyncio.run(_run())
class TestGetRunningLoop(unittest.TestCase):
"""get_event_loop() replaced with get_running_loop() in send()."""
def test_send_uses_get_running_loop(self):
"""send() should call get_running_loop, not get_event_loop."""
async def _run():
handler = CommandHandlerBase()
dispatcher = EventDispatcher()
await dispatcher.start()
handler.set_dispatcher(dispatcher)
mock_sender = AsyncMock()
handler._sender_func = mock_sender
# Patch get_running_loop to verify it's called
with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl:
# send with expected_events triggers the loop = asyncio.get_running_loop() path
result = await handler.send(
b"\x01",
expected_events=[EventType.OK],
timeout=0.05,
)
mock_grl.assert_called()
await dispatcher.stop()
asyncio.run(_run())
if __name__ == "__main__":
unittest.main()

View File

@@ -28,6 +28,10 @@ def mock_dispatcher():
sub.unsubscribe = MagicMock() sub.unsubscribe = MagicMock()
dispatcher._last_subscribe_handler = handler dispatcher._last_subscribe_handler = handler
dispatcher._last_subscribe_event_type = event_type dispatcher._last_subscribe_event_type = event_type
# Immediately resolve the future so send() doesn't block
asyncio.get_event_loop().call_soon(
handler, Event(event_type, {})
)
return sub return sub
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe) dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
@@ -80,6 +84,13 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher
async def test_send_timeout(command_handler, mock_connection, mock_dispatcher): async def test_send_timeout(command_handler, mock_connection, mock_dispatcher):
# Override to NOT resolve events, so we can test the timeout path
def non_resolving_subscribe(event_type, handler, attribute_filters=None):
sub = MagicMock(spec=Subscription)
sub.unsubscribe = MagicMock()
return sub
mock_dispatcher.subscribe = MagicMock(side_effect=non_resolving_subscribe)
result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1) result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1)
assert result.type == EventType.ERROR assert result.type == EventType.ERROR
assert result.payload == {"reason": "no_event_received"} assert result.payload == {"reason": "no_event_received"}

View File

@@ -0,0 +1,294 @@
"""Tests for reconnect-path fixes."""
import asyncio
import pytest
from meshcore.connection_manager import ConnectionManager
from meshcore.events import Event, EventDispatcher, EventType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class FakeConnection:
"""Minimal stub that satisfies ConnectionProtocol."""
def __init__(self, connect_results=None):
"""
Args:
connect_results: iterator of return values for successive
connect() calls. ``None`` means soft failure; a string
means success; raising is also supported via sentinel.
"""
self._connect_results = list(connect_results or ["ok"])
self._call_index = 0
self.reader = None
async def connect(self):
if self._call_index < len(self._connect_results):
result = self._connect_results[self._call_index]
self._call_index += 1
else:
result = self._connect_results[-1]
if isinstance(result, Exception):
raise result
return result
async def disconnect(self):
pass
async def send(self, data):
pass
def set_reader(self, reader):
self.reader = reader
class RaisingConnection(FakeConnection):
"""Connection that raises on every connect() attempt."""
def __init__(self, exc=None):
super().__init__()
self._exc = exc or ConnectionError("boom")
async def connect(self):
raise self._exc
class _EventCollector:
"""Subscribes to all events and records them."""
def __init__(self, dispatcher: EventDispatcher):
self.events: list[Event] = []
dispatcher.subscribe(None, self._on_event)
async def _on_event(self, event: Event):
self.events.append(event)
# ---------------------------------------------------------------------------
# TCP connect() should return a plain value, not an asyncio.Future
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tcp_connect_returns_plain_string():
"""TCPConnection.connect() returns self.host (a plain string), not an
asyncio.Future. We test indirectly via ConnectionManager — the
CONNECTED event payload should contain a plain string, not a Future
object."""
conn = FakeConnection(connect_results=["10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(conn, dispatcher)
result = await mgr.connect()
assert result == "10.0.0.1"
# Give the dispatcher a moment to deliver the event
await asyncio.sleep(0.05)
connected_events = [e for e in collector.events if e.type == EventType.CONNECTED]
assert len(connected_events) == 1
payload = connected_events[0].payload
assert payload["connection_info"] == "10.0.0.1"
# The payload value must NOT be an asyncio.Future
assert not isinstance(payload["connection_info"], asyncio.Future)
finally:
await dispatcher.stop()
# ---------------------------------------------------------------------------
# Reconnect attempts must not compound (no tail-recursive create_task)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reconnect_loop_does_not_compound():
"""_attempt_reconnect must use a single iterative loop. After
max_reconnect_attempts failures, exactly that many connect() calls
should have been made — no exponential fan-out from orphaned tasks."""
# All attempts fail (return None)
conn = FakeConnection(connect_results=[None, None, None, None])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3,
)
mgr._is_connected = True # simulate a live connection
await mgr.handle_disconnect("test_disconnect")
# Wait for the reconnect loop to exhaust all attempts
# (3 attempts × 1s sleep each, but we can just await the task)
if mgr._reconnect_task:
await mgr._reconnect_task
# Exactly 3 connect() calls should have been made
assert conn._call_index == 3
# A DISCONNECTED event with max_attempts_exceeded should have fired
await asyncio.sleep(0.05)
disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED]
assert len(disconnected) == 1
assert disconnected[0].payload.get("max_attempts_exceeded") is True
finally:
await dispatcher.stop()
@pytest.mark.asyncio
async def test_disconnect_cancels_reconnect_loop():
"""disconnect() during an active reconnect loop must cancel the
single task cleanly — no orphaned tasks left running."""
# Simulate a connection that always fails (returns None), giving us
# time to call disconnect() mid-loop.
conn = FakeConnection(connect_results=[None, None, None, None, None])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
mgr = ConnectionManager(
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5,
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
# Let the first attempt start (wait just past the 1s sleep)
await asyncio.sleep(1.2)
assert conn._call_index >= 1 # at least one attempt made
# Now disconnect — should cancel the loop
await mgr.disconnect()
assert mgr._reconnect_task is None
calls_at_cancel = conn._call_index
# Wait a bit and confirm no more attempts happened
await asyncio.sleep(2)
assert conn._call_index == calls_at_cancel
finally:
await dispatcher.stop()
# ---------------------------------------------------------------------------
# reconnect_callback (send_appstart) is called after reconnect
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reconnect_callback_called_after_reconnect():
"""When ConnectionManager reconnects successfully, the
reconnect_callback (e.g. send_appstart) must be invoked."""
callback_called = []
async def fake_appstart():
callback_called.append(True)
# First connect() fails (None), second succeeds
conn = FakeConnection(connect_results=[None, "10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
mgr = ConnectionManager(
conn, dispatcher,
auto_reconnect=True,
max_reconnect_attempts=3,
reconnect_callback=fake_appstart,
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
if mgr._reconnect_task:
await mgr._reconnect_task
assert len(callback_called) == 1
finally:
await dispatcher.stop()
@pytest.mark.asyncio
async def test_reconnect_callback_failure_does_not_crash_loop():
"""If the reconnect_callback raises, the reconnect still counts as
successful (transport is up) — the callback failure is logged but
does not crash the loop or leave the manager in a broken state."""
async def failing_callback():
raise RuntimeError("appstart failed")
# connect() succeeds on first attempt
conn = FakeConnection(connect_results=["10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(
conn, dispatcher,
auto_reconnect=True,
max_reconnect_attempts=3,
reconnect_callback=failing_callback,
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
if mgr._reconnect_task:
await mgr._reconnect_task
# Despite callback failure, CONNECTED event should have fired
await asyncio.sleep(0.05)
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
assert len(connected) == 1
assert mgr._is_connected is True
finally:
await dispatcher.stop()
# ---------------------------------------------------------------------------
# connect() returning None is a soft failure (BLE scan miss)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_connect_none_is_soft_failure():
"""When connect() returns None (e.g. BLE scan found no device),
ConnectionManager.connect() should NOT set _is_connected and should
NOT emit a CONNECTED event."""
conn = FakeConnection(connect_results=[None])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
collector = _EventCollector(dispatcher)
mgr = ConnectionManager(conn, dispatcher)
result = await mgr.connect()
assert result is None
assert mgr._is_connected is False
await asyncio.sleep(0.05)
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
assert len(connected) == 0
finally:
await dispatcher.stop()
@pytest.mark.asyncio
async def test_no_reconnect_callback_is_noop():
"""When no reconnect_callback is provided (backwards compat for
direct ConnectionManager users), reconnect should still work."""
conn = FakeConnection(connect_results=["10.0.0.1"])
dispatcher = EventDispatcher()
await dispatcher.start()
try:
mgr = ConnectionManager(
conn, dispatcher,
auto_reconnect=True,
max_reconnect_attempts=3,
# No reconnect_callback — default None
)
mgr._is_connected = True
await mgr.handle_disconnect("test_disconnect")
if mgr._reconnect_task:
await mgr._reconnect_task
assert mgr._is_connected is True
finally:
await dispatcher.stop()

View File

@@ -0,0 +1,236 @@
"""Verification tests for error response handling fixes.
The tests confirm that error responses are surfaced cleanly instead
of causing KeyError, TypeError, NameError, or silent fallthrough.
"""
import asyncio
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from meshcore.commands import CommandHandler
from meshcore.events import EventType, Event, Subscription
pytestmark = pytest.mark.asyncio
VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes
# ── Fixtures ───────────────────────────────────────────────────────
@pytest.fixture
def mock_connection():
connection = MagicMock()
connection.send = AsyncMock()
return connection
@pytest.fixture
def mock_dispatcher():
dispatcher = MagicMock()
dispatcher.wait_for_event = AsyncMock()
dispatcher.dispatch = AsyncMock()
def fake_subscribe(event_type, handler, attribute_filters=None):
sub = MagicMock(spec=Subscription)
sub.unsubscribe = MagicMock()
dispatcher._last_subscribe_handler = handler
dispatcher._last_subscribe_event_type = event_type
return sub
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
return dispatcher
@pytest.fixture
def command_handler(mock_connection, mock_dispatcher):
handler = CommandHandler()
async def sender(data):
await mock_connection.send(data)
handler._sender_func = sender
handler.dispatcher = mock_dispatcher
return handler
def setup_error_response(mock_dispatcher):
"""Configure dispatcher to return an ERROR event for any subscribe."""
def fake_subscribe(evt_type, handler, attr_filters=None):
sub = MagicMock(spec=Subscription)
sub.unsubscribe = MagicMock()
# Always fire ERROR regardless of which event type was subscribed
if evt_type == EventType.ERROR:
asyncio.get_event_loop().call_soon(
handler, Event(EventType.ERROR, {"reason": "test_error"})
)
return sub
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
def setup_event_response(mock_dispatcher, event_type, payload):
"""Configure dispatcher to return a specific event."""
def fake_subscribe(evt_type, handler, attr_filters=None):
sub = MagicMock(spec=Subscription)
sub.unsubscribe = MagicMock()
if evt_type == event_type:
asyncio.get_event_loop().call_soon(
handler, Event(event_type, payload)
)
return sub
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
# ── Event.is_error() helper ──────────────────────────────────
async def test_event_is_error_true():
"""is_error() returns True for ERROR events."""
event = Event(EventType.ERROR, {"reason": "test"})
assert event.is_error() is True
async def test_event_is_error_false():
"""is_error() returns False for non-ERROR events."""
event = Event(EventType.OK, {})
assert event.is_error() is False
event2 = Event(EventType.SELF_INFO, {"name": "test"})
assert event2.is_error() is False
# ── send_msg_with_retry continues on ERROR ──────────────
async def test_send_msg_with_retry_error_no_keyerror(
command_handler, mock_dispatcher
):
"""send_msg_with_retry returns None (exhausted retries) on
persistent ERROR instead of raising KeyError on missing 'expected_ack'."""
setup_error_response(mock_dispatcher)
# Provide a mock contact so the path logic doesn't interfere
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
# max_attempts=2 so it retries once then gives up
result = await command_handler.send_msg_with_retry(
VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1
)
# Should return None (no ACK received) rather than raising KeyError
assert result is None
# ── send_appstart includes ERROR in expected events ──────────
async def test_send_appstart_returns_error(
command_handler, mock_dispatcher
):
"""send_appstart returns ERROR event instead of hanging on timeout."""
setup_error_response(mock_dispatcher)
result = await command_handler.send_appstart()
assert result.type == EventType.ERROR
assert result.is_error() is True
assert result.payload["reason"] == "test_error"
# ── device setters return ERROR from send_appstart ───────────
async def test_set_telemetry_mode_base_error(
command_handler, mock_dispatcher
):
"""set_telemetry_mode_base returns ERROR instead of KeyError."""
setup_error_response(mock_dispatcher)
result = await command_handler.set_telemetry_mode_base(1)
assert result.is_error()
assert result.payload["reason"] == "test_error"
async def test_set_telemetry_mode_loc_error(
command_handler, mock_dispatcher
):
"""set_telemetry_mode_loc returns ERROR instead of KeyError."""
setup_error_response(mock_dispatcher)
result = await command_handler.set_telemetry_mode_loc(1)
assert result.is_error()
async def test_set_telemetry_mode_env_error(
command_handler, mock_dispatcher
):
"""set_telemetry_mode_env returns ERROR instead of KeyError."""
setup_error_response(mock_dispatcher)
result = await command_handler.set_telemetry_mode_env(1)
assert result.is_error()
async def test_set_manual_add_contacts_error(
command_handler, mock_dispatcher
):
"""set_manual_add_contacts returns ERROR instead of KeyError."""
setup_error_response(mock_dispatcher)
result = await command_handler.set_manual_add_contacts(True)
assert result.is_error()
async def test_set_advert_loc_policy_error(
command_handler, mock_dispatcher
):
"""set_advert_loc_policy returns ERROR instead of KeyError."""
setup_error_response(mock_dispatcher)
result = await command_handler.set_advert_loc_policy(1)
assert result.is_error()
async def test_set_multi_acks_error(
command_handler, mock_dispatcher
):
"""set_multi_acks returns ERROR instead of KeyError."""
setup_error_response(mock_dispatcher)
result = await command_handler.set_multi_acks(1)
assert result.is_error()
# ── send_anon_req returns ERROR on contact not found ─────────
async def test_send_anon_req_contact_not_found(
command_handler, mock_dispatcher
):
"""send_anon_req returns ERROR event when contact prefix not found,
instead of raising TypeError on NoneType subscript."""
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
result = await command_handler.send_anon_req(
VALID_PUBKEY_HEX, MagicMock(value=1)
)
assert result.is_error()
assert result.payload["reason"] == "contact_not_found"
# ── send_trace handles unknown path_hash_len without NameError ──
async def test_send_trace_unknown_path_hash_len(
command_handler, mock_connection, mock_dispatcher
):
"""send_trace with a path whose segments don't match any known
path_hash_len returns ERROR cleanly instead of NameError on 'e'."""
# 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8
result = await command_handler.send_trace(
auth_code=0, tag=1, flags=None, path="abcde"
)
assert result.is_error()
assert result.payload["reason"] == "invalid_path_format"

View File

@@ -0,0 +1,364 @@
"""Verification tests for protocol surface gaps.
Each test constructs a mock firmware frame and verifies the SDK dispatches
the correct EventType with the expected payload fields.
"""
import asyncio
import struct
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from meshcore.events import Event, EventType, EventDispatcher
from meshcore.reader import MessageReader
from meshcore.packets import PacketType, CommandType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_reader():
"""Create a MessageReader with a mock dispatcher that records dispatched events."""
dispatcher = MagicMock(spec=EventDispatcher)
dispatched = []
async def _capture(event):
dispatched.append(event)
dispatcher.dispatch = AsyncMock(side_effect=_capture)
reader = MessageReader(dispatcher)
return reader, dispatched
# ---------------------------------------------------------------------------
# CONTACT_DELETED handler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_contact_deleted_dispatches_event():
"""A 33-byte CONTACT_DELETED frame dispatches EventType.CONTACT_DELETED."""
reader, dispatched = _make_reader()
pubkey = bytes(range(32))
frame = bytes([PacketType.CONTACT_DELETED.value]) + pubkey
assert len(frame) == 33
await reader.handle_rx(bytearray(frame))
assert len(dispatched) == 1
evt = dispatched[0]
assert evt.type == EventType.CONTACT_DELETED
assert evt.payload["pubkey"] == pubkey.hex()
assert evt.attributes["pubkey"] == pubkey.hex()
@pytest.mark.asyncio
async def test_contact_deleted_short_frame_ignored():
"""A CONTACT_DELETED frame shorter than 33 bytes is silently dropped."""
reader, dispatched = _make_reader()
# Only 10 bytes — too short
frame = bytes([PacketType.CONTACT_DELETED.value]) + b"\x00" * 9
await reader.handle_rx(bytearray(frame))
assert len(dispatched) == 0
# ---------------------------------------------------------------------------
# CONTACTS_FULL handler + enum entry
# ---------------------------------------------------------------------------
def test_contacts_full_enum_exists():
"""PacketType.CONTACTS_FULL == 0x90."""
assert PacketType.CONTACTS_FULL.value == 0x90
@pytest.mark.asyncio
async def test_contacts_full_dispatches_event():
"""A 1-byte CONTACTS_FULL push dispatches EventType.CONTACTS_FULL."""
reader, dispatched = _make_reader()
frame = bytes([PacketType.CONTACTS_FULL.value])
await reader.handle_rx(bytearray(frame))
assert len(dispatched) == 1
evt = dispatched[0]
assert evt.type == EventType.CONTACTS_FULL
assert evt.payload == {}
# ---------------------------------------------------------------------------
# TUNING_PARAMS handler
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tuning_params_dispatches_event():
"""A 9-byte TUNING_PARAMS frame dispatches with rx_delay and airtime_factor."""
reader, dispatched = _make_reader()
rx_delay = 500
airtime_factor = 200
frame = (
bytes([PacketType.TUNING_PARAMS.value])
+ rx_delay.to_bytes(4, "little")
+ airtime_factor.to_bytes(4, "little")
)
assert len(frame) == 9
await reader.handle_rx(bytearray(frame))
assert len(dispatched) == 1
evt = dispatched[0]
assert evt.type == EventType.TUNING_PARAMS
assert evt.payload["rx_delay"] == 500
assert evt.payload["airtime_factor"] == 200
@pytest.mark.asyncio
async def test_tuning_params_short_frame_dispatches_error():
"""A TUNING_PARAMS frame shorter than 9 bytes dispatches ERROR."""
reader, dispatched = _make_reader()
# Only 5 bytes — too short
frame = bytes([PacketType.TUNING_PARAMS.value]) + b"\x01\x00\x00\x00"
await reader.handle_rx(bytearray(frame))
assert len(dispatched) == 1
evt = dispatched[0]
assert evt.type == EventType.ERROR
assert evt.payload["reason"] == "invalid_frame_length"
# ---------------------------------------------------------------------------
# send_trace() one-byte pad
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_trace_empty_path_pads_to_11_bytes():
"""send_trace() with no path produces an 11-byte packet (not 10)."""
from meshcore.commands.messaging import MessagingCommands
cmd = MessagingCommands.__new__(MessagingCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
cmd.send = mock_send
await cmd.send_trace(auth_code=0, tag=1, flags=0, path=None)
assert captured_data is not None
# cmd(1) + tag(4) + auth(4) + flags(1) + pad(1) = 11
assert len(captured_data) == 11
assert captured_data[-1] == 0x00 # The pad byte
@pytest.mark.asyncio
async def test_send_trace_with_path_no_padding():
"""send_trace() with a non-empty path does NOT add padding."""
from meshcore.commands.messaging import MessagingCommands
cmd = MessagingCommands.__new__(MessagingCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
cmd.send = mock_send
# 2-byte path hash (flags=1 means hash_len=2)
await cmd.send_trace(auth_code=0, tag=1, flags=1, path=b"\xAA\xBB")
assert captured_data is not None
# cmd(1) + tag(4) + auth(4) + flags(1) + path(2) = 12 — no pad needed
assert len(captured_data) == 12
# ---------------------------------------------------------------------------
# Command wrapper: send_raw_data
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_raw_data_wrapper():
"""send_raw_data sends CMD 0x19 + payload."""
from meshcore.commands.messaging import MessagingCommands
cmd = MessagingCommands.__new__(MessagingCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
cmd.send = mock_send
await cmd.send_raw_data(b"\xDE\xAD")
assert captured_data is not None
assert captured_data[0] == 0x19 # CMD_SEND_RAW_DATA
assert captured_data[1:] == b"\xDE\xAD"
# ---------------------------------------------------------------------------
# Command wrapper: has_connection
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_has_connection_wrapper():
"""has_connection sends CMD 0x1c."""
from meshcore.commands.device import DeviceCommands
cmd = DeviceCommands.__new__(DeviceCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.OK, {"value": 1})
cmd.send = mock_send
await cmd.has_connection()
assert captured_data is not None
assert captured_data == b"\x1c"
# ---------------------------------------------------------------------------
# Command wrapper: get_tuning
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_tuning_wrapper():
"""get_tuning sends CMD 0x2b (GET_TUNING_PARAMS = 43)."""
from meshcore.commands.device import DeviceCommands
cmd = DeviceCommands.__new__(DeviceCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.TUNING_PARAMS, {"rx_delay": 500, "airtime_factor": 200})
cmd.send = mock_send
result = await cmd.get_tuning()
assert captured_data == b"\x2b"
assert result.type == EventType.TUNING_PARAMS
# ---------------------------------------------------------------------------
# Command wrapper: get_contact_by_key
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_contact_by_key_wrapper():
"""get_contact_by_key sends CMD 0x1e + 32-byte pubkey."""
from meshcore.commands.contact import ContactCommands
cmd = ContactCommands.__new__(ContactCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.NEXT_CONTACT, {"public_key": "ab" * 32})
cmd.send = mock_send
pubkey = bytes(range(32))
await cmd.get_contact_by_key(pubkey)
assert captured_data is not None
assert captured_data[0] == 0x1E # CMD_GET_CONTACT_BY_KEY
assert captured_data[1:] == pubkey
# ---------------------------------------------------------------------------
# Command wrapper: factory_reset (two-step)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_factory_reset_two_step():
"""factory_reset requires a token from request_factory_reset."""
from meshcore.commands.device import DeviceCommands
cmd = DeviceCommands.__new__(DeviceCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.OK, {})
cmd.send = mock_send
# Step 1: request token
token = await cmd.request_factory_reset()
assert isinstance(token, str)
assert len(token) == 16 # hex-encoded 8 bytes
# Step 2: confirm with wrong token fails
with pytest.raises(ValueError, match="Invalid or expired"):
await cmd.confirm_factory_reset("wrong_token")
# Step 2: confirm with correct token succeeds
await cmd.confirm_factory_reset(token)
assert captured_data == b"\x33" # CMD_FACTORY_RESET
@pytest.mark.asyncio
async def test_factory_reset_without_request_fails():
"""confirm_factory_reset without request_factory_reset raises ValueError."""
from meshcore.commands.device import DeviceCommands
cmd = DeviceCommands.__new__(DeviceCommands)
with pytest.raises(ValueError, match="Invalid or expired"):
await cmd.confirm_factory_reset("any_token")
# ---------------------------------------------------------------------------
# GET_STATS enum entry
# ---------------------------------------------------------------------------
def test_get_stats_enum_exists():
"""CommandType.GET_STATS == 56."""
assert CommandType.GET_STATS.value == 56
@pytest.mark.asyncio
async def test_get_stats_core_uses_enum():
"""get_stats_core sends CommandType.GET_STATS.value (0x38) + 0x00."""
from meshcore.commands.device import DeviceCommands
cmd = DeviceCommands.__new__(DeviceCommands)
captured_data = None
async def mock_send(data, expected_events, timeout=None):
nonlocal captured_data
captured_data = bytes(data)
return Event(EventType.STATS_CORE, {})
cmd.send = mock_send
await cmd.get_stats_core()
assert captured_data is not None
assert captured_data[0] == CommandType.GET_STATS.value # 0x38 = 56
assert captured_data[1] == 0x00

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import asyncio import asyncio
import logging
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
from meshcore.events import EventType from meshcore.events import EventType
from meshcore.reader import MessageReader from meshcore.reader import MessageReader
@@ -88,3 +89,192 @@ async def test_binary_response():
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test_binary_response()) asyncio.run(test_binary_response())
# ---------------------------------------------------------------------------
# Reader/parser crash-safety verification tests
# ---------------------------------------------------------------------------
class _CapturingDispatcher:
"""Quiet dispatcher that records every dispatched event."""
def __init__(self):
self.events = []
async def dispatch(self, event):
self.events.append(event)
@pytest.mark.asyncio
async def test_handle_rx_malformed_frame_logged_and_swallowed(caplog):
"""Malformed frame must not propagate, must be logged with traceback."""
dispatcher = _CapturingDispatcher()
reader = MessageReader(dispatcher)
# 4-byte CHANNEL_MSG_RECV_V3 frame: type byte (0x11) + 1 SNR byte +
# 2 reserved bytes, but no channel_idx byte. The handler will raise
# IndexError on the next dbuf.read(1)[0] when the buffer is empty.
# The umbrella try/except must catch it, log the parse error, and
# return cleanly.
malformed = bytearray.fromhex("11100000")
with caplog.at_level(logging.ERROR, logger="meshcore"):
await reader.handle_rx(malformed) # must not raise
error_records = [r for r in caplog.records if "handle_rx parse error" in r.message]
assert error_records, (
f"Expected an error log containing 'handle_rx parse error'; "
f"got: {[r.message for r in caplog.records]}"
)
# Traceback should be present in the log message
assert "Traceback" in error_records[0].message, (
"Umbrella log message must include a traceback"
)
# No CHANNEL_MSG_RECV event should have been dispatched
assert not any(e.type == EventType.CHANNEL_MSG_RECV for e in dispatcher.events)
@pytest.mark.asyncio
async def test_battery_short_frame_omits_storage_fields():
"""Short BATTERY frame must not silently yield zero used_kb/total_kb."""
dispatcher = _CapturingDispatcher()
reader = MessageReader(dispatcher)
# 3-byte BATTERY frame: type 0x0c + 2 level bytes (no storage tail).
# Pre-fix the `len(data) > 3` gate would have let any frame >= 4 bytes
# through, producing a BATTERY event with bogus zero used_kb/total_kb
# because io.BytesIO.read() returns short data without raising.
# Post-fix (`len(data) >= 11`) the storage fields are skipped entirely.
short_battery = bytearray.fromhex("0c8000")
await reader.handle_rx(short_battery)
battery_events = [e for e in dispatcher.events if e.type == EventType.BATTERY]
assert len(battery_events) == 1, (
f"Expected exactly one BATTERY event, got {len(battery_events)}"
)
payload = battery_events[0].payload
assert payload["level"] == 0x0080, f"Unexpected level: {payload['level']}"
assert "used_kb" not in payload, (
"Short BATTERY frame must not include used_kb (would be a silent zero)"
)
assert "total_kb" not in payload, (
"Short BATTERY frame must not include total_kb (would be a silent zero)"
)
@pytest.mark.asyncio
async def test_battery_too_short_for_level(caplog):
"""BATTERY frame shorter than 3 bytes must be dropped entirely (Option B).
A 1-byte frame (just the packet-type byte 0x0c, no level bytes) would cause
dbuf.read(2) to return b"" and int.from_bytes(b"", ...) to silently yield 0.
The fix adds an early return with a debug log.
"""
dispatcher = _CapturingDispatcher()
reader = MessageReader(dispatcher)
# 1-byte BATTERY frame: only the type byte, no level payload.
too_short = bytearray.fromhex("0c")
with caplog.at_level(logging.DEBUG, logger="meshcore"):
await reader.handle_rx(too_short)
battery_events = [e for e in dispatcher.events if e.type == EventType.BATTERY]
assert len(battery_events) == 0, (
"BATTERY frame shorter than 3 bytes must not dispatch an event"
)
debug_records = [
r for r in caplog.records if "BATTERY frame too short" in r.message
]
assert debug_records, "Expected a debug log about the short BATTERY frame"
@pytest.mark.asyncio
async def test_status_response_short_frame_skipped(caplog):
"""Short STATUS_RESPONSE push frame must be skipped, not parsed with bogus zeros."""
dispatcher = _CapturingDispatcher()
reader = MessageReader(dispatcher)
# 30-byte STATUS_RESPONSE push frame, well below the 60-byte minimum.
# First byte is the type (0x87 = PacketType.STATUS_RESPONSE), the rest
# is arbitrary filler. parse_status with offset=8 reads up through
# data[56:60], so anything < 60 bytes would yield short reads and
# silent zero values pre-fix.
short_status = bytearray([0x87] + [0xAA] * 29)
assert len(short_status) == 30
with caplog.at_level(logging.DEBUG, logger="meshcore"):
await reader.handle_rx(short_status)
status_events = [e for e in dispatcher.events if e.type == EventType.STATUS_RESPONSE]
assert len(status_events) == 0, (
"Short STATUS_RESPONSE push frame must not dispatch a parsed event"
)
assert any(
"STATUS_RESPONSE push frame too short" in r.message for r in caplog.records
), "Expected a debug log line for short STATUS_RESPONSE frames"
@pytest.mark.asyncio
async def test_parse_packet_payload_txt_type_decodes_high_bits():
"""txt_type must decode the high 6 bits of byte 4, not always be 0."""
from Crypto.Cipher import AES
from Crypto.Hash import HMAC, SHA256
from meshcore.meshcore_parser import MeshcorePacketParser
parser = MeshcorePacketParser()
parser.decrypt_channels = True
# Set up a synthetic channel with a known 16-byte AES key. Direct dict
# assignment matches how the parser stores channels (newChannel is async
# and serves the same purpose).
channel_secret = b"\x01" * 16
channel_hash_byte = 0xAB
parser.channels[0] = {
"channel_idx": 0,
"channel_name": "test",
"channel_hash": "ab",
"channel_secret": channel_secret,
}
# 16-byte plaintext (one AES block):
# bytes 0-3 = sender_timestamp (little-endian)
# byte 4 = (txt_type << 2) | attempt
# bytes 5-15 = message + null padding
# Pick txt_type=5, attempt=1 → byte 4 = (5 << 2) | 1 = 0x15.
# Pre-fix uncrypted[4:4] is empty so txt_type would be 0;
# post-fix uncrypted[4:5] yields 0x15 >> 2 = 5.
plaintext = b"\x00\x00\x00\x00\x15hello\x00\x00\x00\x00\x00\x00"
assert len(plaintext) == 16
encrypted = AES.new(channel_secret, AES.MODE_ECB).encrypt(plaintext)
# cipher_mac = first 2 bytes of HMAC-SHA256(channel_secret, encrypted)
h = HMAC.new(channel_secret, digestmod=SHA256)
h.update(encrypted)
cipher_mac = h.digest()[:2]
# pkt_payload layout: 1-byte chan_hash + 2-byte cipher_mac + ciphertext
pkt_payload = bytes([channel_hash_byte]) + cipher_mac + encrypted
# parsePacketPayload expects the full payload buffer:
# header byte (route_type=1 DIRECT, payload_type=5 channel, ver=0)
# path_byte (path_len=0, path_hash_size=1) → 0x00
# pkt_payload
header = 0x15 # route_type=1, payload_type=5, payload_ver=0
path_byte = 0x00
payload = bytes([header, path_byte]) + pkt_payload
log_data = await parser.parsePacketPayload(payload, log_data={})
assert log_data["payload_type"] == 0x05
assert "txt_type" in log_data, (
f"txt_type missing from log_data — channel decrypt path was not reached. "
f"log_data keys: {list(log_data.keys())}"
)
assert log_data["txt_type"] == 5, (
f"Expected txt_type=5, got {log_data['txt_type']}"
)
assert log_data["attempt"] == 1, (
f"Expected attempt=1, got {log_data['attempt']}"
)

View File

@@ -0,0 +1,238 @@
"""
Verification tests for transport symmetry fixes.
Covers: send symmetry across transports, serial disconnect callback on
transport-lost, serial connect timeout, oversize-frame return, BLE
disconnect-callback re-registration, BLE pairing failure re-raise,
TCP counter per frame not per segment.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from meshcore.tcp_cx import TCPConnection
from meshcore.serial_cx import SerialConnection
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class RecordingReader:
"""Minimal reader mock that records dispatched frames."""
def __init__(self):
self.frames = []
async def handle_rx(self, data):
self.frames.append(bytes(data))
# ---------------------------------------------------------------------------
# TCP send() wraps transport.write in try/except
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tcp_send_write_error_fires_disconnect():
"""TCP: OSError during transport.write fires _disconnect_callback."""
cx = TCPConnection("127.0.0.1", 5000)
cb = AsyncMock()
cx.set_disconnect_callback(cb)
mock_transport = MagicMock()
mock_transport.write.side_effect = OSError("Broken pipe")
cx.transport = mock_transport
cx._send_count = 0
cx._receive_count = 0
await cx.send(b"\x01\x02\x03")
cb.assert_awaited_once()
reason = cb.call_args[0][0]
assert "tcp_write_failed" in reason
# ---------------------------------------------------------------------------
# Serial send() fires disconnect on transport-lost and write error
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_serial_send_no_transport_fires_disconnect():
"""Serial: send() on None transport fires _disconnect_callback ."""
cx = SerialConnection("/dev/null", 115200)
cb = AsyncMock()
cx.set_disconnect_callback(cb)
cx.transport = None
await cx.send(b"\x01")
cb.assert_awaited_once()
reason = cb.call_args[0][0]
assert reason == "serial_transport_lost"
@pytest.mark.asyncio
async def test_serial_send_write_error_fires_disconnect():
"""Serial: OSError during transport.write fires _disconnect_callback."""
cx = SerialConnection("/dev/null", 115200)
cb = AsyncMock()
cx.set_disconnect_callback(cb)
mock_transport = MagicMock()
mock_transport.write.side_effect = OSError("Device not configured")
cx.transport = mock_transport
await cx.send(b"\x01")
cb.assert_awaited_once()
reason = cb.call_args[0][0]
assert "serial_write_failed" in reason
# ---------------------------------------------------------------------------
# BLE send() fires disconnect on transport-lost and write error
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ble_send_no_client_fires_disconnect():
"""BLE: send() with no client fires _disconnect_callback."""
# Can't import BLEConnection directly if bleak isn't installed,
# so test via dynamic import with a guard.
try:
from meshcore.ble_cx import BLEConnection
except ImportError:
pytest.skip("bleak not installed")
# BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it
with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \
patch("meshcore.ble_cx.BleakClient", MagicMock()):
cx = BLEConnection.__new__(BLEConnection)
cx.client = None
cx._user_provided_client = None
cx._user_provided_address = None
cx._user_provided_device = None
cx.address = None
cx.device = None
cx.pin = None
cx.rx_char = None
cb = AsyncMock()
cx._disconnect_callback = cb
result = await cx.send(b"\x01")
assert result is False
cb.assert_awaited_once()
reason = cb.call_args[0][0]
assert reason == "ble_transport_lost"
# ---------------------------------------------------------------------------
# Serial connect() times out if connection_made never fires
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_serial_connect_timeout():
"""Serial: connect() raises TimeoutError if connection_made never fires."""
cx = SerialConnection("/dev/null", 115200)
# Mock create_serial_connection to do nothing (never fires connection_made)
async def mock_create(*args, **kwargs):
return (MagicMock(), MagicMock())
with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection",
side_effect=mock_create):
with pytest.raises(asyncio.TimeoutError):
await cx.connect(timeout=0.1)
# ---------------------------------------------------------------------------
# Oversize frame resets state and returns without dispatch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tcp_oversize_frame_empty_data_returns():
"""TCP: oversize header with no trailing data returns without dispatch."""
cx = TCPConnection("127.0.0.1", 5000)
reader = RecordingReader()
cx.set_reader(reader)
# Build a frame header with size > 300 and no payload data after header
# Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4)
header = b"\x3e" + (500).to_bytes(2, "little")
cx.handle_rx(header)
await asyncio.sleep(0)
# No frames should be dispatched, and state should be reset
assert reader.frames == []
assert cx.header == b""
assert cx.inframe == b""
assert cx.frame_expected_size == 0
@pytest.mark.asyncio
async def test_serial_oversize_frame_empty_data_returns():
"""Serial: oversize header with no trailing data returns without dispatch."""
cx = SerialConnection("/dev/null", 115200)
reader = RecordingReader()
cx.set_reader(reader)
header = b"\x3e" + (500).to_bytes(2, "little")
cx.handle_rx(header)
await asyncio.sleep(0)
assert reader.frames == []
assert cx.header == b""
assert cx.inframe == b""
assert cx.frame_expected_size == 0
# ---------------------------------------------------------------------------
# TCP receive counter increments per MeshCore frame, not per TCP segment
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_tcp_receive_count_per_frame_not_per_segment():
"""TCP: _receive_count increments per completed frame, not per data_received call."""
cx = TCPConnection("127.0.0.1", 5000)
reader = RecordingReader()
cx.set_reader(reader)
cx._receive_count = 0
# Build a 4-byte payload frame
payload = b"\xAA\xBB\xCC\xDD"
frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload
# Split the frame into 3 TCP segments (simulating fragmentation)
protocol = TCPConnection.MCClientProtocol(cx)
protocol.data_received(frame[:2]) # partial header
protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload
protocol.data_received(frame[5:]) # remaining payload
await asyncio.sleep(0)
# 3 data_received calls but only 1 completed frame
assert cx._receive_count == 1
assert reader.frames == [payload]
@pytest.mark.asyncio
async def test_tcp_multiple_frames_count_correctly():
"""TCP: two complete frames in separate segments → _receive_count == 2."""
cx = TCPConnection("127.0.0.1", 5000)
reader = RecordingReader()
cx.set_reader(reader)
cx._receive_count = 0
payload1 = b"\x01\x02"
frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1
payload2 = b"\x03\x04\x05"
frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2
protocol = TCPConnection.MCClientProtocol(cx)
protocol.data_received(frame1)
protocol.data_received(frame2)
await asyncio.sleep(0)
assert cx._receive_count == 2
assert reader.frames == [payload1, payload2]