mirror of
https://github.com/meshcore-dev/meshcore_py.git
synced 2026-06-11 11:56:18 +00:00
Merge branch 'main' into fix/standalone-bugs-and-cleanup
This commit is contained in:
@@ -51,6 +51,14 @@ class BLEConnection:
|
||||
self.pin = pin
|
||||
self.rx_char = None
|
||||
self._disconnect_callback = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
@@ -116,9 +124,12 @@ class BLEConnection:
|
||||
await self.client.pair()
|
||||
logger.info("BLE pairing successful")
|
||||
except Exception as e:
|
||||
logger.warning(f"BLE pairing failed: {e}")
|
||||
# Don't fail the connection if pairing fails, as the device
|
||||
# might already be paired or not require pairing
|
||||
logger.error(f"BLE pairing failed: {e}")
|
||||
# A failed pairing leaves the transport in a half-usable
|
||||
# state — re-raise so the caller gets a clean failure
|
||||
# instead of a silently degraded connection.
|
||||
await self.client.disconnect()
|
||||
raise
|
||||
|
||||
except BleakDeviceNotFoundError:
|
||||
return None
|
||||
@@ -154,8 +165,19 @@ class BLEConnection:
|
||||
self.client = self._user_provided_client
|
||||
self.device = self._user_provided_device
|
||||
|
||||
# Re-register disconnect callback on the reset client so subsequent
|
||||
# disconnects after a reconnect cycle are still detected.
|
||||
if self.client is not None and hasattr(self.client, 'set_disconnected_callback'):
|
||||
try:
|
||||
self.client.set_disconnected_callback(self.handle_disconnect)
|
||||
except Exception:
|
||||
# set_disconnected_callback may not be available on all bleak
|
||||
# versions; the next connect() call will re-create the client
|
||||
# with the callback anyway.
|
||||
pass
|
||||
|
||||
if self._disconnect_callback:
|
||||
asyncio.create_task(self._disconnect_callback("ble_disconnect"))
|
||||
self._spawn_background(self._disconnect_callback("ble_disconnect"))
|
||||
|
||||
def set_disconnect_callback(self, callback):
|
||||
"""Set callback to handle disconnections."""
|
||||
@@ -166,16 +188,24 @@ class BLEConnection:
|
||||
|
||||
def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray):
|
||||
if self.reader is not None:
|
||||
asyncio.create_task(self.reader.handle_rx(data))
|
||||
self._spawn_background(self.reader.handle_rx(data))
|
||||
|
||||
async def send(self, data):
|
||||
if not self.client:
|
||||
logger.error("Client is not connected")
|
||||
if self._disconnect_callback:
|
||||
await self._disconnect_callback("ble_transport_lost")
|
||||
return False
|
||||
if not self.rx_char:
|
||||
logger.error("RX characteristic not found")
|
||||
return False
|
||||
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
|
||||
try:
|
||||
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
|
||||
except Exception as exc:
|
||||
logger.warning(f"BLE write failed: {exc}")
|
||||
if self._disconnect_callback:
|
||||
await self._disconnect_callback(f"ble_write_failed: {exc}")
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from the BLE device."""
|
||||
|
||||
@@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes
|
||||
|
||||
|
||||
class CommandHandlerBase:
|
||||
"""Base class for command handlers.
|
||||
|
||||
.. note::
|
||||
The internal ``asyncio.Lock`` is created lazily on first access
|
||||
so that it binds to the correct running event loop (required for
|
||||
Python 3.9/3.10 compatibility).
|
||||
"""
|
||||
|
||||
DEFAULT_TIMEOUT = 15.0
|
||||
|
||||
def __init__(self, default_timeout: Optional[float] = None):
|
||||
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
|
||||
self._reader: Optional[MessageReader] = None
|
||||
self.dispatcher: Optional[EventDispatcher] = None
|
||||
self._mesh_request_lock = asyncio.Lock()
|
||||
self.__mesh_request_lock: Optional[asyncio.Lock] = None
|
||||
self.default_timeout = (
|
||||
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
|
||||
)
|
||||
|
||||
@property
|
||||
def _mesh_request_lock(self) -> asyncio.Lock:
|
||||
"""Lazy-init lock so it binds to the running loop, not import-time."""
|
||||
if self.__mesh_request_lock is None:
|
||||
self.__mesh_request_lock = asyncio.Lock()
|
||||
return self.__mesh_request_lock
|
||||
|
||||
def set_connection(self, connection: Any) -> None:
|
||||
async def sender(data: bytes) -> None:
|
||||
await connection.send(data)
|
||||
@@ -90,6 +105,14 @@ class CommandHandlerBase:
|
||||
expected_events: Optional[Union[EventType, List[EventType]]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Event:
|
||||
"""Wait for the first of *expected_events* to arrive.
|
||||
|
||||
Returns the first matched ``Event``. When ``EventType.ERROR`` is
|
||||
among the expected types, the caller **must** check
|
||||
``result.is_error()`` before accessing command-specific payload
|
||||
keys — an ERROR payload is ``{"reason": "..."}`` and will
|
||||
``KeyError`` on any other key.
|
||||
"""
|
||||
try:
|
||||
# Convert single event to list if needed
|
||||
if not isinstance(expected_events, list):
|
||||
@@ -129,9 +152,6 @@ class CommandHandlerBase:
|
||||
logger.debug(f"Command error: {e}")
|
||||
return Event(EventType.ERROR, {"error": str(e)})
|
||||
|
||||
return Event(EventType.ERROR, {})
|
||||
|
||||
|
||||
async def send(
|
||||
self,
|
||||
data: bytes,
|
||||
@@ -151,7 +171,14 @@ class CommandHandlerBase:
|
||||
timeout: Timeout in seconds, or None to use default_timeout
|
||||
|
||||
Returns:
|
||||
Event: The full event object that was received in response to the command
|
||||
Event: The full event object that was received in response to
|
||||
the command.
|
||||
|
||||
Important:
|
||||
When ``EventType.ERROR`` is included in *expected_events*, the
|
||||
returned event may be an error response. Callers **must**
|
||||
check ``result.is_error()`` before accessing command-specific
|
||||
payload keys to avoid ``KeyError``.
|
||||
"""
|
||||
if not self.dispatcher:
|
||||
raise RuntimeError("Dispatcher not set, cannot send commands")
|
||||
@@ -170,7 +197,7 @@ class CommandHandlerBase:
|
||||
futures: List[asyncio.Future] = []
|
||||
subscriptions = []
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
for event_type in expected_events:
|
||||
future = loop.create_future()
|
||||
|
||||
@@ -279,6 +306,7 @@ class CommandHandlerBase:
|
||||
contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path
|
||||
if contact is None:
|
||||
logger.error("No contact found")
|
||||
return Event(EventType.ERROR, {"reason": "contact_not_found"})
|
||||
|
||||
zero_hop = False
|
||||
if contact["out_path_len"] == -1:
|
||||
|
||||
@@ -191,6 +191,24 @@ class ContactCommands(CommandHandlerBase):
|
||||
data = b"\x3B"
|
||||
return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR])
|
||||
|
||||
async def get_contact_by_key(self, pubkey: bytes) -> Event:
|
||||
"""N09: Retrieve a single contact by its public key (CMD 30).
|
||||
|
||||
Args:
|
||||
pubkey: 32-byte public key of the contact.
|
||||
|
||||
Returns:
|
||||
Event with the contact data (same format as CONTACT/NEXT_CONTACT),
|
||||
or ERROR if not found.
|
||||
"""
|
||||
if not isinstance(pubkey, (bytes, bytearray)):
|
||||
raise TypeError("pubkey must be bytes-like")
|
||||
# Truncate or pad to 32 bytes
|
||||
key_bytes = bytes(pubkey[:32])
|
||||
logger.debug(f"Getting contact by key: {key_bytes.hex()}")
|
||||
data = b"\x1e" + key_bytes
|
||||
return await self.send(data, [EventType.NEXT_CONTACT, EventType.ERROR])
|
||||
|
||||
async def get_advert_path(self, key: DestinationType) -> Event:
|
||||
key_bytes = _validate_destination(key, prefix_length=32)
|
||||
logger.debug(f"getting advert path for: {key} {key_bytes.hex()}")
|
||||
|
||||
@@ -4,6 +4,7 @@ from hashlib import sha256
|
||||
from typing import Optional
|
||||
|
||||
from ..events import Event, EventType
|
||||
from ..packets import CommandType
|
||||
from .base import CommandHandlerBase, DestinationType, _validate_destination
|
||||
|
||||
logger = logging.getLogger("meshcore")
|
||||
@@ -13,7 +14,7 @@ class DeviceCommands(CommandHandlerBase):
|
||||
async def send_appstart(self) -> Event:
|
||||
logger.debug("Sending appstart command")
|
||||
b1 = bytearray(b"\x01\x03 mccli")
|
||||
return await self.send(b1, [EventType.SELF_INFO])
|
||||
return await self.send(b1, [EventType.SELF_INFO, EventType.ERROR])
|
||||
|
||||
async def send_device_query(self) -> Event:
|
||||
logger.debug("Sending device query command")
|
||||
@@ -129,32 +130,50 @@ class DeviceCommands(CommandHandlerBase):
|
||||
return await self.send(data, [EventType.OK, EventType.ERROR])
|
||||
|
||||
async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event:
|
||||
infos = (await self.send_appstart()).payload
|
||||
result = await self.send_appstart()
|
||||
if result.is_error():
|
||||
return result
|
||||
infos = result.payload
|
||||
infos["telemetry_mode_base"] = telemetry_mode_base
|
||||
return await self.set_other_params_from_infos(infos)
|
||||
|
||||
async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event:
|
||||
infos = (await self.send_appstart()).payload
|
||||
result = await self.send_appstart()
|
||||
if result.is_error():
|
||||
return result
|
||||
infos = result.payload
|
||||
infos["telemetry_mode_loc"] = telemetry_mode_loc
|
||||
return await self.set_other_params_from_infos(infos)
|
||||
|
||||
async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event:
|
||||
infos = (await self.send_appstart()).payload
|
||||
result = await self.send_appstart()
|
||||
if result.is_error():
|
||||
return result
|
||||
infos = result.payload
|
||||
infos["telemetry_mode_env"] = telemetry_mode_env
|
||||
return await self.set_other_params_from_infos(infos)
|
||||
|
||||
async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event:
|
||||
infos = (await self.send_appstart()).payload
|
||||
result = await self.send_appstart()
|
||||
if result.is_error():
|
||||
return result
|
||||
infos = result.payload
|
||||
infos["manual_add_contacts"] = manual_add_contacts
|
||||
return await self.set_other_params_from_infos(infos)
|
||||
|
||||
async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event:
|
||||
infos = (await self.send_appstart()).payload
|
||||
result = await self.send_appstart()
|
||||
if result.is_error():
|
||||
return result
|
||||
infos = result.payload
|
||||
infos["adv_loc_policy"] = advert_loc_policy
|
||||
return await self.set_other_params_from_infos(infos)
|
||||
|
||||
async def set_multi_acks(self, multi_acks: int) -> Event:
|
||||
infos = (await self.send_appstart()).payload
|
||||
result = await self.send_appstart()
|
||||
if result.is_error():
|
||||
return result
|
||||
infos = result.payload
|
||||
infos["multi_acks"] = multi_acks
|
||||
return await self.set_other_params_from_infos(infos)
|
||||
|
||||
@@ -273,20 +292,89 @@ class DeviceCommands(CommandHandlerBase):
|
||||
|
||||
return await self.sign_finish(timeout=timeout, data_size=len(data))
|
||||
|
||||
async def has_connection(self) -> Event:
|
||||
"""N09: Check if the device has an active connection (CMD 28).
|
||||
|
||||
Returns:
|
||||
Event with a 1-byte response indicating connection status,
|
||||
or ERROR.
|
||||
"""
|
||||
logger.debug("Checking device connection status")
|
||||
return await self.send(b"\x1c", [EventType.OK, EventType.ERROR])
|
||||
|
||||
async def get_tuning(self) -> Event:
|
||||
"""N03/N09: Request current tuning parameters (CMD_GET_TUNING_PARAMS = 43).
|
||||
|
||||
Firmware responds with RESP_CODE_TUNING_PARAMS (23): 9 bytes containing
|
||||
rx_delay (4 bytes LE) and airtime_factor (4 bytes LE).
|
||||
|
||||
Returns:
|
||||
Event of type TUNING_PARAMS with rx_delay and airtime_factor,
|
||||
or ERROR.
|
||||
"""
|
||||
logger.debug("Getting tuning parameters")
|
||||
return await self.send(b"\x2b", [EventType.TUNING_PARAMS, EventType.ERROR])
|
||||
|
||||
async def request_factory_reset(self) -> str:
|
||||
"""N09: Request a factory reset token (step 1 of 2).
|
||||
|
||||
This method returns a confirmation token string. Pass it to
|
||||
``confirm_factory_reset(token)`` to actually execute the reset.
|
||||
The two-step pattern is a Python-side safety measure; the firmware
|
||||
itself has no token verification.
|
||||
|
||||
Returns:
|
||||
A confirmation token string to pass to confirm_factory_reset().
|
||||
"""
|
||||
import secrets
|
||||
token = secrets.token_hex(8)
|
||||
logger.warning(
|
||||
"Factory reset requested. Call confirm_factory_reset('%s') to proceed. "
|
||||
"This will ERASE ALL DATA on the device.", token
|
||||
)
|
||||
# Store the token on the instance for validation
|
||||
self._factory_reset_token = token
|
||||
return token
|
||||
|
||||
async def confirm_factory_reset(self, token: str) -> Event:
|
||||
"""N09: Execute factory reset after token confirmation (step 2 of 2).
|
||||
|
||||
Args:
|
||||
token: The token returned by request_factory_reset().
|
||||
|
||||
Returns:
|
||||
Event with OK or ERROR.
|
||||
|
||||
Raises:
|
||||
ValueError: If the token does not match.
|
||||
"""
|
||||
expected = getattr(self, "_factory_reset_token", None)
|
||||
if expected is None or token != expected:
|
||||
raise ValueError(
|
||||
"Invalid or expired factory reset token. "
|
||||
"Call request_factory_reset() first."
|
||||
)
|
||||
self._factory_reset_token = None # Consume the token
|
||||
logger.warning("Executing factory reset — all device data will be erased")
|
||||
return await self.send(b"\x33", [EventType.OK, EventType.ERROR])
|
||||
|
||||
async def get_stats_core(self) -> Event:
|
||||
logger.debug("Getting core statistics")
|
||||
# CMD_GET_STATS (56) + STATS_TYPE_CORE (0)
|
||||
return await self.send(b"\x38\x00", [EventType.STATS_CORE, EventType.ERROR])
|
||||
# R04: Use CommandType enum instead of literal bytes
|
||||
cmd = bytes([CommandType.GET_STATS.value, 0x00]) # GET_STATS + STATS_TYPE_CORE
|
||||
return await self.send(cmd, [EventType.STATS_CORE, EventType.ERROR])
|
||||
|
||||
async def get_stats_radio(self) -> Event:
|
||||
logger.debug("Getting radio statistics")
|
||||
# CMD_GET_STATS (56) + STATS_TYPE_RADIO (1)
|
||||
return await self.send(b"\x38\x01", [EventType.STATS_RADIO, EventType.ERROR])
|
||||
# R04: Use CommandType enum instead of literal bytes
|
||||
cmd = bytes([CommandType.GET_STATS.value, 0x01]) # GET_STATS + STATS_TYPE_RADIO
|
||||
return await self.send(cmd, [EventType.STATS_RADIO, EventType.ERROR])
|
||||
|
||||
async def get_stats_packets(self) -> Event:
|
||||
logger.debug("Getting packet statistics")
|
||||
# CMD_GET_STATS (56) + STATS_TYPE_PACKETS (2)
|
||||
return await self.send(b"\x38\x02", [EventType.STATS_PACKETS, EventType.ERROR])
|
||||
# R04: Use CommandType enum instead of literal bytes
|
||||
cmd = bytes([CommandType.GET_STATS.value, 0x02]) # GET_STATS + STATS_TYPE_PACKETS
|
||||
return await self.send(cmd, [EventType.STATS_PACKETS, EventType.ERROR])
|
||||
|
||||
async def get_allowed_repeat_freq(self) -> Event:
|
||||
logger.debug("Getting allowed repeat freqs")
|
||||
|
||||
@@ -144,8 +144,12 @@ class MessagingCommands(CommandHandlerBase):
|
||||
logger.info(f"Retry sending msg: {attempts + 1}")
|
||||
|
||||
result = await self.send_msg(dst, msg, timestamp, attempt=attempts)
|
||||
if result.type == EventType.ERROR:
|
||||
logger.error(f"⚠️ Failed to send message: {result.payload}")
|
||||
if result.is_error():
|
||||
logger.error(f"Failed to send message: {result.payload}")
|
||||
attempts += 1
|
||||
if flood:
|
||||
flood_attempts += 1
|
||||
continue
|
||||
|
||||
exp_ack = result.payload["expected_ack"].hex()
|
||||
timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout
|
||||
@@ -255,7 +259,7 @@ class MessagingCommands(CommandHandlerBase):
|
||||
elif path_hash_len == 8 :
|
||||
flags = 3
|
||||
else :
|
||||
logger.error(f"Invalid path format: {e}")
|
||||
logger.error(f"Invalid path format: unknown path_hash_len {path_hash_len}")
|
||||
return Event(EventType.ERROR, {"reason": "invalid_path_format"})
|
||||
else:
|
||||
flags = 0
|
||||
@@ -291,12 +295,34 @@ class MessagingCommands(CommandHandlerBase):
|
||||
cmd_data.append(flags)
|
||||
cmd_data.extend(path_bytes)
|
||||
|
||||
# N05: Firmware requires strict len > 10 (MyMesh.cpp:1620).
|
||||
# When path is empty, cmd(1)+tag(4)+auth(4)+flags(1) = 10 bytes exactly,
|
||||
# which is silently rejected. Pad with one zero byte to reach 11.
|
||||
if len(cmd_data) <= 10:
|
||||
cmd_data.append(0x00)
|
||||
|
||||
logger.debug(
|
||||
f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}"
|
||||
)
|
||||
|
||||
return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR])
|
||||
|
||||
async def send_raw_data(self, payload: bytes) -> Event:
|
||||
"""N09: Send raw data via CMD_SEND_RAW_DATA (25).
|
||||
|
||||
Sends an arbitrary payload through the mesh network.
|
||||
|
||||
Args:
|
||||
payload: Raw bytes to send.
|
||||
|
||||
Returns:
|
||||
Event with MSG_SENT or ERROR.
|
||||
"""
|
||||
if not isinstance(payload, (bytes, bytearray)):
|
||||
raise TypeError("payload must be bytes-like")
|
||||
data = b"\x19" + bytes(payload)
|
||||
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
|
||||
|
||||
async def set_flood_scope(self, scope):
|
||||
if scope is None:
|
||||
logger.debug(f"Resetting scope")
|
||||
|
||||
@@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Any, Callable, Protocol
|
||||
from typing import Optional, Any, Awaitable, Callable, Protocol
|
||||
from .events import Event, EventType
|
||||
|
||||
logger = logging.getLogger("meshcore")
|
||||
|
||||
|
||||
class ConnectionProtocol(Protocol):
|
||||
"""Protocol defining the interface that connection classes must implement."""
|
||||
"""Protocol defining the interface that connection classes must implement.
|
||||
|
||||
Return contract for connect():
|
||||
- On success: return a truthy value (typically an address string)
|
||||
that identifies the connection. This value is included in the
|
||||
CONNECTED event payload as ``connection_info``.
|
||||
- On failure: return ``None`` (soft failure — triggers a retry in
|
||||
``_attempt_reconnect``) **or** raise an exception (hard failure —
|
||||
also triggers a retry, logged as an error).
|
||||
"""
|
||||
|
||||
async def connect(self) -> Optional[Any]:
|
||||
"""Connect and return connection info, or None if failed."""
|
||||
@@ -39,11 +48,13 @@ class ConnectionManager:
|
||||
event_dispatcher=None,
|
||||
auto_reconnect: bool = False,
|
||||
max_reconnect_attempts: int = 3,
|
||||
reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||
):
|
||||
self.connection = connection
|
||||
self.event_dispatcher = event_dispatcher
|
||||
self.auto_reconnect = auto_reconnect
|
||||
self.max_reconnect_attempts = max_reconnect_attempts
|
||||
self._reconnect_callback = reconnect_callback
|
||||
|
||||
self._reconnect_attempts = 0
|
||||
self._is_connected = False
|
||||
@@ -109,45 +120,51 @@ class ConnectionManager:
|
||||
)
|
||||
|
||||
async def _attempt_reconnect(self):
|
||||
"""Attempt to reconnect with flat delay."""
|
||||
logger.debug(
|
||||
f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})"
|
||||
)
|
||||
self._reconnect_attempts += 1
|
||||
"""Attempt to reconnect using an iterative loop.
|
||||
|
||||
# Flat 1 second delay for all attempts
|
||||
await asyncio.sleep(1)
|
||||
Runs as a single persistent task for the entire reconnect session.
|
||||
Previous implementation used tail-recursion via create_task which
|
||||
orphaned the running task reference — disconnect() could only cancel
|
||||
the newest pointer, leaving earlier attempts in flight (F03).
|
||||
"""
|
||||
while self._reconnect_attempts < self.max_reconnect_attempts:
|
||||
self._reconnect_attempts += 1
|
||||
logger.debug(
|
||||
f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})"
|
||||
)
|
||||
|
||||
# Flat 1 second delay for all attempts
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
result = await self.connection.connect()
|
||||
if result is not None:
|
||||
self._is_connected = True
|
||||
self._reconnect_attempts = 0
|
||||
|
||||
# Invoke reconnect callback (e.g. send_appstart) if provided
|
||||
if self._reconnect_callback is not None:
|
||||
try:
|
||||
await self._reconnect_callback()
|
||||
except Exception as cb_err:
|
||||
logger.warning(
|
||||
f"Reconnect callback failed: {cb_err}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.connection.connect()
|
||||
if result is not None:
|
||||
self._is_connected = True
|
||||
self._reconnect_attempts = 0
|
||||
await self._emit_event(
|
||||
EventType.CONNECTED,
|
||||
{"connection_info": result, "reconnected": True},
|
||||
)
|
||||
logger.debug("Reconnected successfully")
|
||||
else:
|
||||
# Reconnection failed, try again if we haven't exceeded max attempts
|
||||
if self._reconnect_attempts < self.max_reconnect_attempts:
|
||||
self._reconnect_task = asyncio.create_task(
|
||||
self._attempt_reconnect()
|
||||
)
|
||||
else:
|
||||
await self._emit_event(
|
||||
EventType.DISCONNECTED,
|
||||
{"reason": "reconnect_failed", "max_attempts_exceeded": True},
|
||||
EventType.CONNECTED,
|
||||
{"connection_info": result, "reconnected": True},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Reconnection attempt failed: {e}")
|
||||
if self._reconnect_attempts < self.max_reconnect_attempts:
|
||||
self._reconnect_task = asyncio.create_task(self._attempt_reconnect())
|
||||
else:
|
||||
await self._emit_event(
|
||||
EventType.DISCONNECTED,
|
||||
{"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True},
|
||||
)
|
||||
logger.debug("Reconnected successfully")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"Reconnection attempt failed: {e}")
|
||||
|
||||
# All attempts exhausted
|
||||
await self._emit_event(
|
||||
EventType.DISCONNECTED,
|
||||
{"reason": "reconnect_failed", "max_attempts_exceeded": True},
|
||||
)
|
||||
|
||||
async def _emit_event(self, event_type: EventType, payload: dict):
|
||||
"""Emit connection events if dispatcher is available."""
|
||||
|
||||
@@ -49,6 +49,9 @@ class EventType(Enum):
|
||||
PATH_RESPONSE = "path_response"
|
||||
PRIVATE_KEY = "private_key"
|
||||
DISABLED = "disabled"
|
||||
CONTACT_DELETED = "contact_deleted"
|
||||
CONTACTS_FULL = "contacts_full"
|
||||
TUNING_PARAMS = "tuning_params"
|
||||
CONTROL_DATA = "control_data"
|
||||
DISCOVER_RESPONSE = "discover_response"
|
||||
NEIGHBOURS_RESPONSE = "neighbours_response"
|
||||
@@ -104,6 +107,17 @@ class Event:
|
||||
if kwargs:
|
||||
self.attributes.update(kwargs)
|
||||
|
||||
def is_error(self) -> bool:
|
||||
"""Return True if this event represents an error response.
|
||||
|
||||
Callers that include ``EventType.ERROR`` in their expected-events
|
||||
list **must** check ``result.is_error()`` (or ``result.type ==
|
||||
EventType.ERROR``) before accessing keyed payload fields, because
|
||||
an ERROR payload contains ``{"reason": "..."}`` — not the
|
||||
command-specific keys the caller expects on the happy path.
|
||||
"""
|
||||
return self.type == EventType.ERROR
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
Create a copy of the event.
|
||||
@@ -129,11 +143,28 @@ class Subscription:
|
||||
|
||||
|
||||
class EventDispatcher:
|
||||
"""Event dispatch engine.
|
||||
|
||||
.. note::
|
||||
``start()`` must be called before dispatching or processing events.
|
||||
The internal ``asyncio.Queue`` is created lazily inside ``start()``
|
||||
so that it binds to the correct running event loop (required for
|
||||
Python 3.9/3.10 compatibility).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.queue: asyncio.Queue[Event] = asyncio.Queue()
|
||||
self.queue: Optional[asyncio.Queue[Event]] = None
|
||||
self.subscriptions: List[Subscription] = []
|
||||
self.running = False
|
||||
self._task = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
@@ -166,6 +197,10 @@ class EventDispatcher:
|
||||
self.subscriptions.remove(subscription)
|
||||
|
||||
async def dispatch(self, event: Event):
|
||||
if self.queue is None:
|
||||
raise RuntimeError(
|
||||
"EventDispatcher.start() must be called before dispatching events"
|
||||
)
|
||||
await self.queue.put(event)
|
||||
|
||||
async def _process_events(self):
|
||||
@@ -197,7 +232,7 @@ class EventDispatcher:
|
||||
# returns - avoids the race where create_task schedules the callback after
|
||||
# the waiter has already timed out with done=set().
|
||||
if asyncio.iscoroutinefunction(subscription.callback):
|
||||
asyncio.create_task(self._execute_callback(subscription.callback, event.clone()))
|
||||
self._spawn_background(self._execute_callback(subscription.callback, event.clone()))
|
||||
else:
|
||||
try:
|
||||
subscription.callback(event.clone())
|
||||
@@ -220,6 +255,8 @@ class EventDispatcher:
|
||||
|
||||
async def start(self):
|
||||
if not self.running:
|
||||
if self.queue is None:
|
||||
self.queue = asyncio.Queue()
|
||||
self.running = True
|
||||
self._task = asyncio.create_task(self._process_events())
|
||||
|
||||
@@ -227,7 +264,12 @@ class EventDispatcher:
|
||||
if self.running:
|
||||
self.running = False
|
||||
if self._task:
|
||||
await self.queue.join()
|
||||
if self.queue is not None:
|
||||
await self.queue.join()
|
||||
# Wait for any in-flight async callbacks to complete before
|
||||
# tearing down (F07: task_done fires before callbacks finish).
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
|
||||
@@ -28,10 +28,17 @@ class MeshCore:
|
||||
auto_reconnect: bool = False,
|
||||
max_reconnect_attempts: int = 3,
|
||||
):
|
||||
# Wrap connection with ConnectionManager
|
||||
# Wrap connection with ConnectionManager.
|
||||
# The reconnect callback ensures send_appstart() runs after every
|
||||
# transport-level reconnect, which is required by firmware to
|
||||
# initialize the session (F02).
|
||||
self.dispatcher = EventDispatcher()
|
||||
self.connection_manager = ConnectionManager(
|
||||
cx, self.dispatcher, auto_reconnect, max_reconnect_attempts
|
||||
cx,
|
||||
self.dispatcher,
|
||||
auto_reconnect,
|
||||
max_reconnect_attempts,
|
||||
reconnect_callback=self._on_reconnect,
|
||||
)
|
||||
self.cx = self.connection_manager # For backward compatibility
|
||||
|
||||
@@ -174,6 +181,15 @@ class MeshCore:
|
||||
return None
|
||||
return mc
|
||||
|
||||
async def _on_reconnect(self):
|
||||
"""Callback invoked by ConnectionManager after a successful reconnect.
|
||||
|
||||
Firmware requires CMD_APP_START after every transport-level connection
|
||||
to initialize the session. MeshCore.connect() does this on the initial
|
||||
connection; this callback ensures it also happens on reconnects (F02).
|
||||
"""
|
||||
await self.commands.send_appstart()
|
||||
|
||||
async def connect(self):
|
||||
await self.dispatcher.start()
|
||||
result = await self.connection_manager.connect()
|
||||
|
||||
@@ -42,6 +42,28 @@ class MeshcorePacketParser:
|
||||
Returns :
|
||||
completed log_data
|
||||
"""
|
||||
# Minimum viable payload is 2 bytes (1 header + 1 path_byte) for a
|
||||
# direct route. Anything shorter is provably broken — for example,
|
||||
# the LOG_DATA branch in reader.py only requires `len(data) > 3`,
|
||||
# which means a 4-byte LOG_DATA frame produces a 1-byte payload
|
||||
# here, and `path_byte = pbuf.read(1)[0]` further down would raise
|
||||
# IndexError on the empty buffer. Populate sentinel values so the
|
||||
# caller's downstream `log_data['route_type']` etc. lookups don't
|
||||
# KeyError, then return early.
|
||||
if len(payload) < 2:
|
||||
logger.debug(f"parsePacketPayload: payload too short ({len(payload)} bytes < 2), returning sentinel log_data")
|
||||
log_data["route_type"] = -1
|
||||
log_data["route_typename"] = "UNK"
|
||||
log_data["payload_type"] = -1
|
||||
log_data["payload_typename"] = "UNK"
|
||||
log_data["payload_ver"] = 0
|
||||
log_data["path_len"] = 0
|
||||
log_data["path_hash_size"] = 1
|
||||
log_data["path"] = ""
|
||||
log_data["pkt_payload"] = b""
|
||||
log_data["pkt_hash"] = 0
|
||||
return log_data
|
||||
|
||||
pbuf = io.BytesIO(payload)
|
||||
|
||||
header = pbuf.read(1)[0]
|
||||
@@ -128,7 +150,7 @@ class MeshcorePacketParser:
|
||||
uncrypted = cipher.decrypt(msg)
|
||||
timestamp = int.from_bytes(uncrypted[0:4], "little", signed=False)
|
||||
attempt = uncrypted[4] & 3
|
||||
txt_type = int.from_bytes(uncrypted[4:4], "little", signed=False) >> 2
|
||||
txt_type = int.from_bytes(uncrypted[4:5], "little", signed=False) >> 2
|
||||
message = uncrypted[5:].strip(b"\0")
|
||||
msg_hash = int.from_bytes(SHA256.new(timestamp.to_bytes(4, "little", signed=False) + message).digest()[0:4], "little", signed=False)
|
||||
log_data["message"] = message.decode("utf-8", "ignore")
|
||||
@@ -149,39 +171,42 @@ class MeshcorePacketParser:
|
||||
del self.channels_log[:25]
|
||||
|
||||
elif not payload is None and payload_type == 0x04: # Advert
|
||||
pk_buf = io.BytesIO(pkt_payload)
|
||||
adv_key = pk_buf.read(32).hex()
|
||||
adv_timestamp = int.from_bytes(pk_buf.read(4), "little", signed=False)
|
||||
signature = pk_buf.read(64).hex()
|
||||
flags = pk_buf.read(1)[0]
|
||||
adv_type = flags & 0x0F
|
||||
adv_lat = None
|
||||
adv_lon = None
|
||||
adv_feat1 = None
|
||||
adv_feat2 = None
|
||||
if flags & 0x10 > 0: #has location
|
||||
adv_lat = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0
|
||||
adv_lon = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0
|
||||
if flags & 0x20 > 0: #has feature1
|
||||
adv_feat1 = pk_buf.read(2).hex()
|
||||
if flags & 0x40 > 0: #has feature2
|
||||
adv_feat2 = pk_buf.read(2).hex()
|
||||
if flags & 0x80 > 0: #has name
|
||||
adv_name = pk_buf.read().decode("utf-8", "ignore").strip("\x00")
|
||||
log_data["adv_name"] = adv_name
|
||||
try:
|
||||
pk_buf = io.BytesIO(pkt_payload)
|
||||
adv_key = pk_buf.read(32).hex()
|
||||
adv_timestamp = int.from_bytes(pk_buf.read(4), "little", signed=False)
|
||||
signature = pk_buf.read(64).hex()
|
||||
flags = pk_buf.read(1)[0]
|
||||
adv_type = flags & 0x0F
|
||||
adv_lat = None
|
||||
adv_lon = None
|
||||
adv_feat1 = None
|
||||
adv_feat2 = None
|
||||
if flags & 0x10 > 0: #has location
|
||||
adv_lat = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0
|
||||
adv_lon = int.from_bytes(pk_buf.read(4), "little", signed=True)/1000000.0
|
||||
if flags & 0x20 > 0: #has feature1
|
||||
adv_feat1 = pk_buf.read(2).hex()
|
||||
if flags & 0x40 > 0: #has feature2
|
||||
adv_feat2 = pk_buf.read(2).hex()
|
||||
if flags & 0x80 > 0: #has name
|
||||
adv_name = pk_buf.read().decode("utf-8", "ignore").strip("\x00")
|
||||
log_data["adv_name"] = adv_name
|
||||
|
||||
log_data["adv_key"] = adv_key
|
||||
log_data["adv_timestamp"] = adv_timestamp
|
||||
log_data["signature"] = signature
|
||||
log_data["adv_flags"] = flags
|
||||
log_data["adv_type"] = adv_type
|
||||
if not adv_lat is None :
|
||||
log_data["adv_lat"] = adv_lat
|
||||
if not adv_lon is None :
|
||||
log_data["adv_lon"] = adv_lon
|
||||
if not adv_feat1 is None:
|
||||
log_data["adv_feat1"] = adv_feat1
|
||||
if not adv_feat2 is None:
|
||||
log_data["adv_feat2"] = adv_feat2
|
||||
log_data["adv_key"] = adv_key
|
||||
log_data["adv_timestamp"] = adv_timestamp
|
||||
log_data["signature"] = signature
|
||||
log_data["adv_flags"] = flags
|
||||
log_data["adv_type"] = adv_type
|
||||
if not adv_lat is None :
|
||||
log_data["adv_lat"] = adv_lat
|
||||
if not adv_lon is None :
|
||||
log_data["adv_lon"] = adv_lon
|
||||
if not adv_feat1 is None:
|
||||
log_data["adv_feat1"] = adv_feat1
|
||||
if not adv_feat2 is None:
|
||||
log_data["adv_feat2"] = adv_feat2
|
||||
except (IndexError, ValueError) as e:
|
||||
logger.debug(f"parsePacketPayload: malformed ADVERT payload ({type(e).__name__}: {e}), len={len(pkt_payload)}")
|
||||
|
||||
return log_data
|
||||
|
||||
@@ -71,6 +71,7 @@ class CommandType(Enum):
|
||||
SET_AUTOADD_CONFIG = 58
|
||||
GET_AUTOADD_CONFIG = 59
|
||||
GET_ALLOWED_REPEAT_FREQ = 60
|
||||
GET_STATS = 56 # R04: CMD_GET_STATS — used by get_stats_core/radio/packets
|
||||
SET_PATH_HASH_MODE = 61
|
||||
|
||||
# Packet prefixes for the protocol
|
||||
@@ -120,3 +121,6 @@ class PacketType(Enum):
|
||||
PATH_DISCOVERY_RESPONSE = 0x8D
|
||||
CONTROL_DATA = 0x8E
|
||||
CONTACT_DELETED = 0x8F
|
||||
CONTACTS_FULL = 0x90 # N02: MyMesh::onContactsFull() — 1-byte push, no payload
|
||||
# Note: 0x90 == ControlType.NODE_DISCOVER_RESP in a different namespace.
|
||||
# Not a literal conflict (PacketType vs ControlType), but a maintenance hazard.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,11 +20,19 @@ class SerialConnection:
|
||||
self._disconnect_callback = None
|
||||
self.cx_dly = cx_dly
|
||||
self._connected_event = asyncio.Event()
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
self.frame_expected_size = 0
|
||||
self.inframe = b""
|
||||
self.header = b""
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
class MCSerialClientProtocol(asyncio.Protocol):
|
||||
def __init__(self, cx):
|
||||
self.cx = cx
|
||||
@@ -44,7 +52,7 @@ class SerialConnection:
|
||||
self.cx._connected_event.clear()
|
||||
|
||||
if self.cx._disconnect_callback:
|
||||
asyncio.create_task(self.cx._disconnect_callback("serial_disconnect"))
|
||||
self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect"))
|
||||
|
||||
def pause_writing(self):
|
||||
logger.debug("pause writing")
|
||||
@@ -52,9 +60,13 @@ class SerialConnection:
|
||||
def resume_writing(self):
|
||||
logger.debug("resume writing")
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self, timeout: float = 10.0):
|
||||
"""
|
||||
Connects to the device
|
||||
Connects to the device.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for connection_made callback.
|
||||
Defaults to 10.0. Raises asyncio.TimeoutError on expiry.
|
||||
"""
|
||||
self._connected_event.clear()
|
||||
|
||||
@@ -66,7 +78,7 @@ class SerialConnection:
|
||||
baudrate=self.baudrate,
|
||||
)
|
||||
|
||||
await self._connected_event.wait()
|
||||
await asyncio.wait_for(self._connected_event.wait(), timeout=timeout)
|
||||
logger.info("Serial Connection started")
|
||||
return self.port
|
||||
|
||||
@@ -102,7 +114,7 @@ class SerialConnection:
|
||||
self.frame_expected_size = 0
|
||||
if len(data) > 0: # rerun handle_rx on remaining data
|
||||
self.handle_rx(data)
|
||||
return
|
||||
return # nothing left to process after reset
|
||||
|
||||
upbound = self.frame_expected_size - len(self.inframe)
|
||||
if len(data) < upbound:
|
||||
@@ -114,7 +126,7 @@ class SerialConnection:
|
||||
data = data[upbound:]
|
||||
if self.reader is not None:
|
||||
# feed meshcore reader
|
||||
asyncio.create_task(self.reader.handle_rx(self.inframe))
|
||||
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||
# reset inframe
|
||||
self.inframe = b""
|
||||
self.header = b""
|
||||
@@ -125,11 +137,18 @@ class SerialConnection:
|
||||
async def send(self, data):
|
||||
if not self.transport:
|
||||
logger.error("Transport not connected, cannot send data")
|
||||
if self._disconnect_callback:
|
||||
await self._disconnect_callback("serial_transport_lost")
|
||||
return
|
||||
size = len(data)
|
||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||
logger.debug(f"sending pkt : {pkt}")
|
||||
self.transport.write(pkt)
|
||||
try:
|
||||
self.transport.write(pkt)
|
||||
except OSError as exc:
|
||||
logger.warning(f"Serial write failed: {exc}")
|
||||
if self._disconnect_callback:
|
||||
await self._disconnect_callback(f"serial_write_failed: {exc}")
|
||||
|
||||
async def disconnect(self):
|
||||
"""Close the serial connection."""
|
||||
|
||||
@@ -24,6 +24,14 @@ class TCPConnection:
|
||||
self.frame_expected_size = 0
|
||||
self.header = b""
|
||||
self.inframe = b""
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def _spawn_background(self, coro) -> asyncio.Task:
|
||||
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
return task
|
||||
|
||||
class MCClientProtocol(asyncio.Protocol):
|
||||
def __init__(self, cx):
|
||||
@@ -38,7 +46,6 @@ class TCPConnection:
|
||||
|
||||
def data_received(self, data):
|
||||
logger.debug("data received")
|
||||
self.cx._receive_count += 1
|
||||
self.cx.handle_rx(data)
|
||||
|
||||
def error_received(self, exc):
|
||||
@@ -47,7 +54,7 @@ class TCPConnection:
|
||||
def connection_lost(self, exc):
|
||||
logger.debug("TCP server closed the connection")
|
||||
if self.cx._disconnect_callback:
|
||||
asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect"))
|
||||
self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect"))
|
||||
|
||||
async def connect(self):
|
||||
"""
|
||||
@@ -59,10 +66,7 @@ class TCPConnection:
|
||||
)
|
||||
|
||||
logger.info("TCP Connection started")
|
||||
future = asyncio.Future()
|
||||
future.set_result(self.host)
|
||||
|
||||
return future
|
||||
return self.host
|
||||
|
||||
def set_reader(self, reader):
|
||||
self.reader = reader
|
||||
@@ -96,7 +100,7 @@ class TCPConnection:
|
||||
self.frame_expected_size = 0
|
||||
if len(data) > 0: # rerun handle_rx on remaining data
|
||||
self.handle_rx(data)
|
||||
return
|
||||
return # nothing left to process after reset
|
||||
|
||||
upbound = self.frame_expected_size - len(self.inframe)
|
||||
if len(data) < upbound :
|
||||
@@ -106,9 +110,13 @@ class TCPConnection:
|
||||
|
||||
self.inframe = self.inframe + data[0:upbound]
|
||||
data = data[upbound:]
|
||||
# Increment per completed MeshCore frame, not per TCP segment (N04).
|
||||
# The threshold heuristic in send() compares _send_count vs
|
||||
# _receive_count — counting per-segment skews it under fragmentation.
|
||||
self._receive_count += 1
|
||||
if self.reader is not None:
|
||||
# feed meshcore reader
|
||||
asyncio.create_task(self.reader.handle_rx(self.inframe))
|
||||
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||
# reset inframe
|
||||
self.inframe = b""
|
||||
self.header = b""
|
||||
@@ -137,7 +145,12 @@ class TCPConnection:
|
||||
size = len(data)
|
||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||
logger.debug(f"sending pkt : {pkt}")
|
||||
self.transport.write(pkt)
|
||||
try:
|
||||
self.transport.write(pkt)
|
||||
except (OSError, ConnectionResetError) as exc:
|
||||
logger.warning(f"TCP write failed: {exc}")
|
||||
if self._disconnect_callback:
|
||||
await self._disconnect_callback(f"tcp_write_failed: {exc}")
|
||||
|
||||
async def disconnect(self):
|
||||
"""Close the TCP connection."""
|
||||
|
||||
@@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase):
|
||||
|
||||
@patch("meshcore.ble_cx.BleakClient")
|
||||
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
|
||||
"""Test BLE connection with PIN when pairing fails but connection continues"""
|
||||
"""Test BLE connection with PIN when pairing fails — re-raises (F17)."""
|
||||
# Arrange
|
||||
mock_client_instance = self._get_mock_bleak_client()
|
||||
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
|
||||
@@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase):
|
||||
pin = "123456"
|
||||
ble_conn = BLEConnection(address=address, pin=pin)
|
||||
|
||||
# Act
|
||||
result = asyncio.run(ble_conn.connect())
|
||||
|
||||
# Assert
|
||||
# Act & Assert — pairing failure now re-raises instead of being
|
||||
# swallowed, because a half-usable transport is worse than a clean
|
||||
# failure (forensics finding F17).
|
||||
with self.assertRaises(Exception) as ctx:
|
||||
asyncio.run(ble_conn.connect())
|
||||
self.assertIn("Pairing failed", str(ctx.exception))
|
||||
mock_client_instance.connect.assert_called_once()
|
||||
mock_client_instance.pair.assert_called_once()
|
||||
mock_client_instance.start_notify.assert_called_once_with(
|
||||
UART_TX_CHAR_UUID, ble_conn.handle_rx
|
||||
)
|
||||
# Connection should still succeed even if pairing fails
|
||||
self.assertEqual(result, address)
|
||||
# disconnect should be called to clean up the failed connection
|
||||
mock_client_instance.disconnect.assert_called_once()
|
||||
|
||||
@patch("meshcore.ble_cx.BleakClient")
|
||||
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):
|
||||
|
||||
235
tests/unit/test_asyncio_lifecycle.py
Normal file
235
tests/unit/test_asyncio_lifecycle.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Verification tests for asyncio lifecycle fixes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from meshcore.events import Event, EventDispatcher, EventType
|
||||
from meshcore.tcp_cx import TCPConnection
|
||||
from meshcore.serial_cx import SerialConnection
|
||||
from meshcore.commands.base import CommandHandlerBase
|
||||
|
||||
|
||||
class TestBackgroundTaskTracking(unittest.TestCase):
|
||||
"""Fire-and-forget create_task calls must be tracked to prevent GC."""
|
||||
|
||||
def test_tcp_spawn_background_retains_task(self):
|
||||
"""TCP _spawn_background adds the task to _background_tasks."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = cx._spawn_background(dummy())
|
||||
assert task in cx._background_tasks
|
||||
await completed.wait()
|
||||
# After completion, done_callback should have discarded it
|
||||
await asyncio.sleep(0) # let done callback fire
|
||||
assert task not in cx._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_serial_spawn_background_retains_task(self):
|
||||
"""Serial _spawn_background adds the task to _background_tasks."""
|
||||
async def _run():
|
||||
with patch("meshcore.serial_cx.asyncio.Event") as mock_event:
|
||||
mock_event.return_value = MagicMock()
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = cx._spawn_background(dummy())
|
||||
assert task in cx._background_tasks
|
||||
await completed.wait()
|
||||
await asyncio.sleep(0)
|
||||
assert task not in cx._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_event_dispatcher_spawn_background_retains_task(self):
|
||||
"""EventDispatcher _spawn_background adds task to _background_tasks."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = dispatcher._spawn_background(dummy())
|
||||
assert task in dispatcher._background_tasks
|
||||
await completed.wait()
|
||||
await asyncio.sleep(0)
|
||||
assert task not in dispatcher._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_tcp_handle_rx_uses_tracked_task(self):
|
||||
"""TCP handle_rx dispatches reader.handle_rx via _spawn_background."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
reader = AsyncMock()
|
||||
reader.handle_rx = AsyncMock()
|
||||
cx.set_reader(reader)
|
||||
|
||||
# Build a minimal valid frame: 0x3e + 2-byte LE size + payload
|
||||
payload = b"\x01\x02\x03"
|
||||
size = len(payload).to_bytes(2, "little")
|
||||
frame = b"\x3e" + size + payload
|
||||
|
||||
cx.handle_rx(frame)
|
||||
# Task should be tracked
|
||||
assert len(cx._background_tasks) == 1
|
||||
# Let task complete
|
||||
await asyncio.sleep(0.05)
|
||||
reader.handle_rx.assert_awaited_once_with(payload)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_tcp_connection_lost_uses_tracked_task(self):
|
||||
"""TCP connection_lost dispatches disconnect callback via _spawn_background."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
callback = AsyncMock()
|
||||
cx.set_disconnect_callback(callback)
|
||||
|
||||
protocol = cx.MCClientProtocol(cx)
|
||||
protocol.connection_lost(None)
|
||||
|
||||
assert len(cx._background_tasks) == 1
|
||||
await asyncio.sleep(0.05)
|
||||
callback.assert_awaited_once_with("tcp_disconnect")
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_gc_does_not_cancel_tracked_tasks(self):
|
||||
"""Tracked tasks survive GC pressure (the whole point of tracking)."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
result = []
|
||||
|
||||
async def slow_task():
|
||||
await asyncio.sleep(0.05)
|
||||
result.append("done")
|
||||
|
||||
cx._spawn_background(slow_task())
|
||||
# Force GC — untracked tasks could be collected here
|
||||
gc.collect()
|
||||
await asyncio.sleep(0.1)
|
||||
assert result == ["done"]
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestTaskDoneCorrectness(unittest.TestCase):
|
||||
"""EventDispatcher.stop() must wait for in-flight async callbacks."""
|
||||
|
||||
def test_stop_waits_for_async_callbacks(self):
|
||||
"""stop() should not return until async callbacks have completed."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
|
||||
callback_completed = False
|
||||
|
||||
async def slow_callback(event):
|
||||
nonlocal callback_completed
|
||||
await asyncio.sleep(0.1)
|
||||
callback_completed = True
|
||||
|
||||
dispatcher.subscribe(EventType.OK, slow_callback)
|
||||
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||
|
||||
# Give the dispatch loop a moment to pick up the event
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
# stop() should wait for slow_callback to finish
|
||||
await dispatcher.stop()
|
||||
assert callback_completed, "stop() returned before async callback completed"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestDeferredPrimitiveConstruction(unittest.TestCase):
|
||||
"""Queue and Lock must not bind to import-time loop."""
|
||||
|
||||
def test_event_dispatcher_queue_is_none_before_start(self):
|
||||
"""EventDispatcher.queue should be None until start() is called."""
|
||||
dispatcher = EventDispatcher()
|
||||
assert dispatcher.queue is None
|
||||
|
||||
def test_event_dispatcher_queue_created_on_start(self):
|
||||
"""start() creates the queue."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
assert dispatcher.queue is None
|
||||
await dispatcher.start()
|
||||
assert dispatcher.queue is not None
|
||||
assert isinstance(dispatcher.queue, asyncio.Queue)
|
||||
await dispatcher.stop()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_event_dispatcher_dispatch_before_start_raises(self):
|
||||
"""dispatch() before start() should raise RuntimeError."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_command_handler_lock_is_none_before_use(self):
|
||||
"""CommandHandlerBase lock should be None until first access."""
|
||||
handler = CommandHandlerBase()
|
||||
assert handler._CommandHandlerBase__mesh_request_lock is None
|
||||
|
||||
def test_command_handler_lock_created_on_access(self):
|
||||
"""Accessing _mesh_request_lock creates it lazily."""
|
||||
async def _run():
|
||||
handler = CommandHandlerBase()
|
||||
lock = handler._mesh_request_lock
|
||||
assert isinstance(lock, asyncio.Lock)
|
||||
# Second access returns same instance
|
||||
assert handler._mesh_request_lock is lock
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestGetRunningLoop(unittest.TestCase):
|
||||
"""get_event_loop() replaced with get_running_loop() in send()."""
|
||||
|
||||
def test_send_uses_get_running_loop(self):
|
||||
"""send() should call get_running_loop, not get_event_loop."""
|
||||
async def _run():
|
||||
handler = CommandHandlerBase()
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
handler.set_dispatcher(dispatcher)
|
||||
|
||||
mock_sender = AsyncMock()
|
||||
handler._sender_func = mock_sender
|
||||
|
||||
# Patch get_running_loop to verify it's called
|
||||
with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl:
|
||||
# send with expected_events triggers the loop = asyncio.get_running_loop() path
|
||||
result = await handler.send(
|
||||
b"\x01",
|
||||
expected_events=[EventType.OK],
|
||||
timeout=0.05,
|
||||
)
|
||||
mock_grl.assert_called()
|
||||
|
||||
await dispatcher.stop()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -28,6 +28,10 @@ def mock_dispatcher():
|
||||
sub.unsubscribe = MagicMock()
|
||||
dispatcher._last_subscribe_handler = handler
|
||||
dispatcher._last_subscribe_event_type = event_type
|
||||
# Immediately resolve the future so send() doesn't block
|
||||
asyncio.get_event_loop().call_soon(
|
||||
handler, Event(event_type, {})
|
||||
)
|
||||
return sub
|
||||
|
||||
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
@@ -80,6 +84,13 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher
|
||||
|
||||
|
||||
async def test_send_timeout(command_handler, mock_connection, mock_dispatcher):
|
||||
# Override to NOT resolve events, so we can test the timeout path
|
||||
def non_resolving_subscribe(event_type, handler, attribute_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
return sub
|
||||
mock_dispatcher.subscribe = MagicMock(side_effect=non_resolving_subscribe)
|
||||
|
||||
result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1)
|
||||
assert result.type == EventType.ERROR
|
||||
assert result.payload == {"reason": "no_event_received"}
|
||||
|
||||
294
tests/unit/test_connection_manager.py
Normal file
294
tests/unit/test_connection_manager.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Tests for reconnect-path fixes."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from meshcore.connection_manager import ConnectionManager
|
||||
from meshcore.events import Event, EventDispatcher, EventType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeConnection:
|
||||
"""Minimal stub that satisfies ConnectionProtocol."""
|
||||
|
||||
def __init__(self, connect_results=None):
|
||||
"""
|
||||
Args:
|
||||
connect_results: iterator of return values for successive
|
||||
connect() calls. ``None`` means soft failure; a string
|
||||
means success; raising is also supported via sentinel.
|
||||
"""
|
||||
self._connect_results = list(connect_results or ["ok"])
|
||||
self._call_index = 0
|
||||
self.reader = None
|
||||
|
||||
async def connect(self):
|
||||
if self._call_index < len(self._connect_results):
|
||||
result = self._connect_results[self._call_index]
|
||||
self._call_index += 1
|
||||
else:
|
||||
result = self._connect_results[-1]
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def send(self, data):
|
||||
pass
|
||||
|
||||
def set_reader(self, reader):
|
||||
self.reader = reader
|
||||
|
||||
|
||||
class RaisingConnection(FakeConnection):
|
||||
"""Connection that raises on every connect() attempt."""
|
||||
|
||||
def __init__(self, exc=None):
|
||||
super().__init__()
|
||||
self._exc = exc or ConnectionError("boom")
|
||||
|
||||
async def connect(self):
|
||||
raise self._exc
|
||||
|
||||
|
||||
class _EventCollector:
|
||||
"""Subscribes to all events and records them."""
|
||||
|
||||
def __init__(self, dispatcher: EventDispatcher):
|
||||
self.events: list[Event] = []
|
||||
dispatcher.subscribe(None, self._on_event)
|
||||
|
||||
async def _on_event(self, event: Event):
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TCP connect() should return a plain value, not an asyncio.Future
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_connect_returns_plain_string():
|
||||
"""TCPConnection.connect() returns self.host (a plain string), not an
|
||||
asyncio.Future. We test indirectly via ConnectionManager — the
|
||||
CONNECTED event payload should contain a plain string, not a Future
|
||||
object."""
|
||||
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(conn, dispatcher)
|
||||
|
||||
result = await mgr.connect()
|
||||
|
||||
assert result == "10.0.0.1"
|
||||
# Give the dispatcher a moment to deliver the event
|
||||
await asyncio.sleep(0.05)
|
||||
connected_events = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||
assert len(connected_events) == 1
|
||||
payload = connected_events[0].payload
|
||||
assert payload["connection_info"] == "10.0.0.1"
|
||||
# The payload value must NOT be an asyncio.Future
|
||||
assert not isinstance(payload["connection_info"], asyncio.Future)
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconnect attempts must not compound (no tail-recursive create_task)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_loop_does_not_compound():
|
||||
"""_attempt_reconnect must use a single iterative loop. After
|
||||
max_reconnect_attempts failures, exactly that many connect() calls
|
||||
should have been made — no exponential fan-out from orphaned tasks."""
|
||||
# All attempts fail (return None)
|
||||
conn = FakeConnection(connect_results=[None, None, None, None])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3,
|
||||
)
|
||||
mgr._is_connected = True # simulate a live connection
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
# Wait for the reconnect loop to exhaust all attempts
|
||||
# (3 attempts × 1s sleep each, but we can just await the task)
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
# Exactly 3 connect() calls should have been made
|
||||
assert conn._call_index == 3
|
||||
|
||||
# A DISCONNECTED event with max_attempts_exceeded should have fired
|
||||
await asyncio.sleep(0.05)
|
||||
disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED]
|
||||
assert len(disconnected) == 1
|
||||
assert disconnected[0].payload.get("max_attempts_exceeded") is True
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_reconnect_loop():
|
||||
"""disconnect() during an active reconnect loop must cancel the
|
||||
single task cleanly — no orphaned tasks left running."""
|
||||
# Simulate a connection that always fails (returns None), giving us
|
||||
# time to call disconnect() mid-loop.
|
||||
conn = FakeConnection(connect_results=[None, None, None, None, None])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5,
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
|
||||
# Let the first attempt start (wait just past the 1s sleep)
|
||||
await asyncio.sleep(1.2)
|
||||
assert conn._call_index >= 1 # at least one attempt made
|
||||
|
||||
# Now disconnect — should cancel the loop
|
||||
await mgr.disconnect()
|
||||
|
||||
assert mgr._reconnect_task is None
|
||||
calls_at_cancel = conn._call_index
|
||||
|
||||
# Wait a bit and confirm no more attempts happened
|
||||
await asyncio.sleep(2)
|
||||
assert conn._call_index == calls_at_cancel
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reconnect_callback (send_appstart) is called after reconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_callback_called_after_reconnect():
|
||||
"""When ConnectionManager reconnects successfully, the
|
||||
reconnect_callback (e.g. send_appstart) must be invoked."""
|
||||
callback_called = []
|
||||
|
||||
async def fake_appstart():
|
||||
callback_called.append(True)
|
||||
|
||||
# First connect() fails (None), second succeeds
|
||||
conn = FakeConnection(connect_results=[None, "10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher,
|
||||
auto_reconnect=True,
|
||||
max_reconnect_attempts=3,
|
||||
reconnect_callback=fake_appstart,
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
assert len(callback_called) == 1
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_callback_failure_does_not_crash_loop():
|
||||
"""If the reconnect_callback raises, the reconnect still counts as
|
||||
successful (transport is up) — the callback failure is logged but
|
||||
does not crash the loop or leave the manager in a broken state."""
|
||||
async def failing_callback():
|
||||
raise RuntimeError("appstart failed")
|
||||
|
||||
# connect() succeeds on first attempt
|
||||
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher,
|
||||
auto_reconnect=True,
|
||||
max_reconnect_attempts=3,
|
||||
reconnect_callback=failing_callback,
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
# Despite callback failure, CONNECTED event should have fired
|
||||
await asyncio.sleep(0.05)
|
||||
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||
assert len(connected) == 1
|
||||
assert mgr._is_connected is True
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# connect() returning None is a soft failure (BLE scan miss)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_none_is_soft_failure():
|
||||
"""When connect() returns None (e.g. BLE scan found no device),
|
||||
ConnectionManager.connect() should NOT set _is_connected and should
|
||||
NOT emit a CONNECTED event."""
|
||||
conn = FakeConnection(connect_results=[None])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(conn, dispatcher)
|
||||
|
||||
result = await mgr.connect()
|
||||
|
||||
assert result is None
|
||||
assert mgr._is_connected is False
|
||||
await asyncio.sleep(0.05)
|
||||
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||
assert len(connected) == 0
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_reconnect_callback_is_noop():
|
||||
"""When no reconnect_callback is provided (backwards compat for
|
||||
direct ConnectionManager users), reconnect should still work."""
|
||||
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher,
|
||||
auto_reconnect=True,
|
||||
max_reconnect_attempts=3,
|
||||
# No reconnect_callback — default None
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
assert mgr._is_connected is True
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
236
tests/unit/test_error_handling.py
Normal file
236
tests/unit/test_error_handling.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Verification tests for error response handling fixes.
|
||||
|
||||
The tests confirm that error responses are surfaced cleanly instead
|
||||
of causing KeyError, TypeError, NameError, or silent fallthrough.
|
||||
"""
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from meshcore.commands import CommandHandler
|
||||
from meshcore.events import EventType, Event, Subscription
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection():
|
||||
connection = MagicMock()
|
||||
connection.send = AsyncMock()
|
||||
return connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dispatcher():
|
||||
dispatcher = MagicMock()
|
||||
dispatcher.wait_for_event = AsyncMock()
|
||||
dispatcher.dispatch = AsyncMock()
|
||||
|
||||
def fake_subscribe(event_type, handler, attribute_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
dispatcher._last_subscribe_handler = handler
|
||||
dispatcher._last_subscribe_event_type = event_type
|
||||
return sub
|
||||
|
||||
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
return dispatcher
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def command_handler(mock_connection, mock_dispatcher):
|
||||
handler = CommandHandler()
|
||||
|
||||
async def sender(data):
|
||||
await mock_connection.send(data)
|
||||
|
||||
handler._sender_func = sender
|
||||
handler.dispatcher = mock_dispatcher
|
||||
return handler
|
||||
|
||||
|
||||
def setup_error_response(mock_dispatcher):
|
||||
"""Configure dispatcher to return an ERROR event for any subscribe."""
|
||||
def fake_subscribe(evt_type, handler, attr_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
# Always fire ERROR regardless of which event type was subscribed
|
||||
if evt_type == EventType.ERROR:
|
||||
asyncio.get_event_loop().call_soon(
|
||||
handler, Event(EventType.ERROR, {"reason": "test_error"})
|
||||
)
|
||||
return sub
|
||||
|
||||
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
|
||||
|
||||
def setup_event_response(mock_dispatcher, event_type, payload):
|
||||
"""Configure dispatcher to return a specific event."""
|
||||
def fake_subscribe(evt_type, handler, attr_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
if evt_type == event_type:
|
||||
asyncio.get_event_loop().call_soon(
|
||||
handler, Event(event_type, payload)
|
||||
)
|
||||
return sub
|
||||
|
||||
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
|
||||
|
||||
# ── Event.is_error() helper ──────────────────────────────────
|
||||
|
||||
async def test_event_is_error_true():
|
||||
"""is_error() returns True for ERROR events."""
|
||||
event = Event(EventType.ERROR, {"reason": "test"})
|
||||
assert event.is_error() is True
|
||||
|
||||
|
||||
async def test_event_is_error_false():
|
||||
"""is_error() returns False for non-ERROR events."""
|
||||
event = Event(EventType.OK, {})
|
||||
assert event.is_error() is False
|
||||
event2 = Event(EventType.SELF_INFO, {"name": "test"})
|
||||
assert event2.is_error() is False
|
||||
|
||||
|
||||
# ── send_msg_with_retry continues on ERROR ──────────────
|
||||
|
||||
async def test_send_msg_with_retry_error_no_keyerror(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""send_msg_with_retry returns None (exhausted retries) on
|
||||
persistent ERROR instead of raising KeyError on missing 'expected_ack'."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
# Provide a mock contact so the path logic doesn't interfere
|
||||
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
|
||||
|
||||
# max_attempts=2 so it retries once then gives up
|
||||
result = await command_handler.send_msg_with_retry(
|
||||
VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1
|
||||
)
|
||||
|
||||
# Should return None (no ACK received) rather than raising KeyError
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── send_appstart includes ERROR in expected events ──────────
|
||||
|
||||
async def test_send_appstart_returns_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""send_appstart returns ERROR event instead of hanging on timeout."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.send_appstart()
|
||||
|
||||
assert result.type == EventType.ERROR
|
||||
assert result.is_error() is True
|
||||
assert result.payload["reason"] == "test_error"
|
||||
|
||||
|
||||
# ── device setters return ERROR from send_appstart ───────────
|
||||
|
||||
async def test_set_telemetry_mode_base_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_telemetry_mode_base returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_telemetry_mode_base(1)
|
||||
|
||||
assert result.is_error()
|
||||
assert result.payload["reason"] == "test_error"
|
||||
|
||||
|
||||
async def test_set_telemetry_mode_loc_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_telemetry_mode_loc returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_telemetry_mode_loc(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_telemetry_mode_env_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_telemetry_mode_env returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_telemetry_mode_env(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_manual_add_contacts_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_manual_add_contacts returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_manual_add_contacts(True)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_advert_loc_policy_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_advert_loc_policy returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_advert_loc_policy(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_multi_acks_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_multi_acks returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_multi_acks(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
# ── send_anon_req returns ERROR on contact not found ─────────
|
||||
|
||||
async def test_send_anon_req_contact_not_found(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""send_anon_req returns ERROR event when contact prefix not found,
|
||||
instead of raising TypeError on NoneType subscript."""
|
||||
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
|
||||
|
||||
result = await command_handler.send_anon_req(
|
||||
VALID_PUBKEY_HEX, MagicMock(value=1)
|
||||
)
|
||||
|
||||
assert result.is_error()
|
||||
assert result.payload["reason"] == "contact_not_found"
|
||||
|
||||
|
||||
# ── send_trace handles unknown path_hash_len without NameError ──
|
||||
|
||||
async def test_send_trace_unknown_path_hash_len(
|
||||
command_handler, mock_connection, mock_dispatcher
|
||||
):
|
||||
"""send_trace with a path whose segments don't match any known
|
||||
path_hash_len returns ERROR cleanly instead of NameError on 'e'."""
|
||||
# 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8
|
||||
result = await command_handler.send_trace(
|
||||
auth_code=0, tag=1, flags=None, path="abcde"
|
||||
)
|
||||
|
||||
assert result.is_error()
|
||||
assert result.payload["reason"] == "invalid_path_format"
|
||||
364
tests/unit/test_protocol_surface_gaps.py
Normal file
364
tests/unit/test_protocol_surface_gaps.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""Verification tests for protocol surface gaps.
|
||||
|
||||
Each test constructs a mock firmware frame and verifies the SDK dispatches
|
||||
the correct EventType with the expected payload fields.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from meshcore.events import Event, EventType, EventDispatcher
|
||||
from meshcore.reader import MessageReader
|
||||
from meshcore.packets import PacketType, CommandType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_reader():
|
||||
"""Create a MessageReader with a mock dispatcher that records dispatched events."""
|
||||
dispatcher = MagicMock(spec=EventDispatcher)
|
||||
dispatched = []
|
||||
|
||||
async def _capture(event):
|
||||
dispatched.append(event)
|
||||
|
||||
dispatcher.dispatch = AsyncMock(side_effect=_capture)
|
||||
reader = MessageReader(dispatcher)
|
||||
return reader, dispatched
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CONTACT_DELETED handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contact_deleted_dispatches_event():
|
||||
"""A 33-byte CONTACT_DELETED frame dispatches EventType.CONTACT_DELETED."""
|
||||
reader, dispatched = _make_reader()
|
||||
pubkey = bytes(range(32))
|
||||
frame = bytes([PacketType.CONTACT_DELETED.value]) + pubkey
|
||||
assert len(frame) == 33
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.CONTACT_DELETED
|
||||
assert evt.payload["pubkey"] == pubkey.hex()
|
||||
assert evt.attributes["pubkey"] == pubkey.hex()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contact_deleted_short_frame_ignored():
|
||||
"""A CONTACT_DELETED frame shorter than 33 bytes is silently dropped."""
|
||||
reader, dispatched = _make_reader()
|
||||
# Only 10 bytes — too short
|
||||
frame = bytes([PacketType.CONTACT_DELETED.value]) + b"\x00" * 9
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CONTACTS_FULL handler + enum entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_contacts_full_enum_exists():
|
||||
"""PacketType.CONTACTS_FULL == 0x90."""
|
||||
assert PacketType.CONTACTS_FULL.value == 0x90
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contacts_full_dispatches_event():
|
||||
"""A 1-byte CONTACTS_FULL push dispatches EventType.CONTACTS_FULL."""
|
||||
reader, dispatched = _make_reader()
|
||||
frame = bytes([PacketType.CONTACTS_FULL.value])
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.CONTACTS_FULL
|
||||
assert evt.payload == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TUNING_PARAMS handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tuning_params_dispatches_event():
|
||||
"""A 9-byte TUNING_PARAMS frame dispatches with rx_delay and airtime_factor."""
|
||||
reader, dispatched = _make_reader()
|
||||
rx_delay = 500
|
||||
airtime_factor = 200
|
||||
frame = (
|
||||
bytes([PacketType.TUNING_PARAMS.value])
|
||||
+ rx_delay.to_bytes(4, "little")
|
||||
+ airtime_factor.to_bytes(4, "little")
|
||||
)
|
||||
assert len(frame) == 9
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.TUNING_PARAMS
|
||||
assert evt.payload["rx_delay"] == 500
|
||||
assert evt.payload["airtime_factor"] == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tuning_params_short_frame_dispatches_error():
|
||||
"""A TUNING_PARAMS frame shorter than 9 bytes dispatches ERROR."""
|
||||
reader, dispatched = _make_reader()
|
||||
# Only 5 bytes — too short
|
||||
frame = bytes([PacketType.TUNING_PARAMS.value]) + b"\x01\x00\x00\x00"
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.ERROR
|
||||
assert evt.payload["reason"] == "invalid_frame_length"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_trace() one-byte pad
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_trace_empty_path_pads_to_11_bytes():
|
||||
"""send_trace() with no path produces an 11-byte packet (not 10)."""
|
||||
from meshcore.commands.messaging import MessagingCommands
|
||||
|
||||
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.send_trace(auth_code=0, tag=1, flags=0, path=None)
|
||||
|
||||
assert captured_data is not None
|
||||
# cmd(1) + tag(4) + auth(4) + flags(1) + pad(1) = 11
|
||||
assert len(captured_data) == 11
|
||||
assert captured_data[-1] == 0x00 # The pad byte
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_trace_with_path_no_padding():
|
||||
"""send_trace() with a non-empty path does NOT add padding."""
|
||||
from meshcore.commands.messaging import MessagingCommands
|
||||
|
||||
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
# 2-byte path hash (flags=1 means hash_len=2)
|
||||
await cmd.send_trace(auth_code=0, tag=1, flags=1, path=b"\xAA\xBB")
|
||||
|
||||
assert captured_data is not None
|
||||
# cmd(1) + tag(4) + auth(4) + flags(1) + path(2) = 12 — no pad needed
|
||||
assert len(captured_data) == 12
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: send_raw_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raw_data_wrapper():
|
||||
"""send_raw_data sends CMD 0x19 + payload."""
|
||||
from meshcore.commands.messaging import MessagingCommands
|
||||
|
||||
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.send_raw_data(b"\xDE\xAD")
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data[0] == 0x19 # CMD_SEND_RAW_DATA
|
||||
assert captured_data[1:] == b"\xDE\xAD"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: has_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_connection_wrapper():
|
||||
"""has_connection sends CMD 0x1c."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.OK, {"value": 1})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.has_connection()
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data == b"\x1c"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: get_tuning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tuning_wrapper():
|
||||
"""get_tuning sends CMD 0x2b (GET_TUNING_PARAMS = 43)."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.TUNING_PARAMS, {"rx_delay": 500, "airtime_factor": 200})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
result = await cmd.get_tuning()
|
||||
|
||||
assert captured_data == b"\x2b"
|
||||
assert result.type == EventType.TUNING_PARAMS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: get_contact_by_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_contact_by_key_wrapper():
|
||||
"""get_contact_by_key sends CMD 0x1e + 32-byte pubkey."""
|
||||
from meshcore.commands.contact import ContactCommands
|
||||
|
||||
cmd = ContactCommands.__new__(ContactCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.NEXT_CONTACT, {"public_key": "ab" * 32})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
pubkey = bytes(range(32))
|
||||
await cmd.get_contact_by_key(pubkey)
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data[0] == 0x1E # CMD_GET_CONTACT_BY_KEY
|
||||
assert captured_data[1:] == pubkey
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: factory_reset (two-step)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_reset_two_step():
|
||||
"""factory_reset requires a token from request_factory_reset."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.OK, {})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
# Step 1: request token
|
||||
token = await cmd.request_factory_reset()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) == 16 # hex-encoded 8 bytes
|
||||
|
||||
# Step 2: confirm with wrong token fails
|
||||
with pytest.raises(ValueError, match="Invalid or expired"):
|
||||
await cmd.confirm_factory_reset("wrong_token")
|
||||
|
||||
# Step 2: confirm with correct token succeeds
|
||||
await cmd.confirm_factory_reset(token)
|
||||
assert captured_data == b"\x33" # CMD_FACTORY_RESET
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_reset_without_request_fails():
|
||||
"""confirm_factory_reset without request_factory_reset raises ValueError."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid or expired"):
|
||||
await cmd.confirm_factory_reset("any_token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET_STATS enum entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_stats_enum_exists():
|
||||
"""CommandType.GET_STATS == 56."""
|
||||
assert CommandType.GET_STATS.value == 56
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_core_uses_enum():
|
||||
"""get_stats_core sends CommandType.GET_STATS.value (0x38) + 0x00."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.STATS_CORE, {})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.get_stats_core()
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data[0] == CommandType.GET_STATS.value # 0x38 = 56
|
||||
assert captured_data[1] == 0x00
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import AsyncMock
|
||||
from meshcore.events import EventType
|
||||
from meshcore.reader import MessageReader
|
||||
@@ -88,3 +89,192 @@ async def test_binary_response():
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_binary_response())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reader/parser crash-safety verification tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _CapturingDispatcher:
|
||||
"""Quiet dispatcher that records every dispatched event."""
|
||||
def __init__(self):
|
||||
self.events = []
|
||||
|
||||
async def dispatch(self, event):
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_rx_malformed_frame_logged_and_swallowed(caplog):
|
||||
"""Malformed frame must not propagate, must be logged with traceback."""
|
||||
dispatcher = _CapturingDispatcher()
|
||||
reader = MessageReader(dispatcher)
|
||||
|
||||
# 4-byte CHANNEL_MSG_RECV_V3 frame: type byte (0x11) + 1 SNR byte +
|
||||
# 2 reserved bytes, but no channel_idx byte. The handler will raise
|
||||
# IndexError on the next dbuf.read(1)[0] when the buffer is empty.
|
||||
# The umbrella try/except must catch it, log the parse error, and
|
||||
# return cleanly.
|
||||
malformed = bytearray.fromhex("11100000")
|
||||
|
||||
with caplog.at_level(logging.ERROR, logger="meshcore"):
|
||||
await reader.handle_rx(malformed) # must not raise
|
||||
|
||||
error_records = [r for r in caplog.records if "handle_rx parse error" in r.message]
|
||||
assert error_records, (
|
||||
f"Expected an error log containing 'handle_rx parse error'; "
|
||||
f"got: {[r.message for r in caplog.records]}"
|
||||
)
|
||||
# Traceback should be present in the log message
|
||||
assert "Traceback" in error_records[0].message, (
|
||||
"Umbrella log message must include a traceback"
|
||||
)
|
||||
# No CHANNEL_MSG_RECV event should have been dispatched
|
||||
assert not any(e.type == EventType.CHANNEL_MSG_RECV for e in dispatcher.events)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_battery_short_frame_omits_storage_fields():
|
||||
"""Short BATTERY frame must not silently yield zero used_kb/total_kb."""
|
||||
dispatcher = _CapturingDispatcher()
|
||||
reader = MessageReader(dispatcher)
|
||||
|
||||
# 3-byte BATTERY frame: type 0x0c + 2 level bytes (no storage tail).
|
||||
# Pre-fix the `len(data) > 3` gate would have let any frame >= 4 bytes
|
||||
# through, producing a BATTERY event with bogus zero used_kb/total_kb
|
||||
# because io.BytesIO.read() returns short data without raising.
|
||||
# Post-fix (`len(data) >= 11`) the storage fields are skipped entirely.
|
||||
short_battery = bytearray.fromhex("0c8000")
|
||||
|
||||
await reader.handle_rx(short_battery)
|
||||
|
||||
battery_events = [e for e in dispatcher.events if e.type == EventType.BATTERY]
|
||||
assert len(battery_events) == 1, (
|
||||
f"Expected exactly one BATTERY event, got {len(battery_events)}"
|
||||
)
|
||||
payload = battery_events[0].payload
|
||||
assert payload["level"] == 0x0080, f"Unexpected level: {payload['level']}"
|
||||
assert "used_kb" not in payload, (
|
||||
"Short BATTERY frame must not include used_kb (would be a silent zero)"
|
||||
)
|
||||
assert "total_kb" not in payload, (
|
||||
"Short BATTERY frame must not include total_kb (would be a silent zero)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_battery_too_short_for_level(caplog):
|
||||
"""BATTERY frame shorter than 3 bytes must be dropped entirely (Option B).
|
||||
|
||||
A 1-byte frame (just the packet-type byte 0x0c, no level bytes) would cause
|
||||
dbuf.read(2) to return b"" and int.from_bytes(b"", ...) to silently yield 0.
|
||||
The fix adds an early return with a debug log.
|
||||
"""
|
||||
dispatcher = _CapturingDispatcher()
|
||||
reader = MessageReader(dispatcher)
|
||||
|
||||
# 1-byte BATTERY frame: only the type byte, no level payload.
|
||||
too_short = bytearray.fromhex("0c")
|
||||
|
||||
with caplog.at_level(logging.DEBUG, logger="meshcore"):
|
||||
await reader.handle_rx(too_short)
|
||||
|
||||
battery_events = [e for e in dispatcher.events if e.type == EventType.BATTERY]
|
||||
assert len(battery_events) == 0, (
|
||||
"BATTERY frame shorter than 3 bytes must not dispatch an event"
|
||||
)
|
||||
debug_records = [
|
||||
r for r in caplog.records if "BATTERY frame too short" in r.message
|
||||
]
|
||||
assert debug_records, "Expected a debug log about the short BATTERY frame"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_response_short_frame_skipped(caplog):
|
||||
"""Short STATUS_RESPONSE push frame must be skipped, not parsed with bogus zeros."""
|
||||
dispatcher = _CapturingDispatcher()
|
||||
reader = MessageReader(dispatcher)
|
||||
|
||||
# 30-byte STATUS_RESPONSE push frame, well below the 60-byte minimum.
|
||||
# First byte is the type (0x87 = PacketType.STATUS_RESPONSE), the rest
|
||||
# is arbitrary filler. parse_status with offset=8 reads up through
|
||||
# data[56:60], so anything < 60 bytes would yield short reads and
|
||||
# silent zero values pre-fix.
|
||||
short_status = bytearray([0x87] + [0xAA] * 29)
|
||||
assert len(short_status) == 30
|
||||
|
||||
with caplog.at_level(logging.DEBUG, logger="meshcore"):
|
||||
await reader.handle_rx(short_status)
|
||||
|
||||
status_events = [e for e in dispatcher.events if e.type == EventType.STATUS_RESPONSE]
|
||||
assert len(status_events) == 0, (
|
||||
"Short STATUS_RESPONSE push frame must not dispatch a parsed event"
|
||||
)
|
||||
assert any(
|
||||
"STATUS_RESPONSE push frame too short" in r.message for r in caplog.records
|
||||
), "Expected a debug log line for short STATUS_RESPONSE frames"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_packet_payload_txt_type_decodes_high_bits():
|
||||
"""txt_type must decode the high 6 bits of byte 4, not always be 0."""
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Hash import HMAC, SHA256
|
||||
from meshcore.meshcore_parser import MeshcorePacketParser
|
||||
|
||||
parser = MeshcorePacketParser()
|
||||
parser.decrypt_channels = True
|
||||
|
||||
# Set up a synthetic channel with a known 16-byte AES key. Direct dict
|
||||
# assignment matches how the parser stores channels (newChannel is async
|
||||
# and serves the same purpose).
|
||||
channel_secret = b"\x01" * 16
|
||||
channel_hash_byte = 0xAB
|
||||
parser.channels[0] = {
|
||||
"channel_idx": 0,
|
||||
"channel_name": "test",
|
||||
"channel_hash": "ab",
|
||||
"channel_secret": channel_secret,
|
||||
}
|
||||
|
||||
# 16-byte plaintext (one AES block):
|
||||
# bytes 0-3 = sender_timestamp (little-endian)
|
||||
# byte 4 = (txt_type << 2) | attempt
|
||||
# bytes 5-15 = message + null padding
|
||||
# Pick txt_type=5, attempt=1 → byte 4 = (5 << 2) | 1 = 0x15.
|
||||
# Pre-fix uncrypted[4:4] is empty so txt_type would be 0;
|
||||
# post-fix uncrypted[4:5] yields 0x15 >> 2 = 5.
|
||||
plaintext = b"\x00\x00\x00\x00\x15hello\x00\x00\x00\x00\x00\x00"
|
||||
assert len(plaintext) == 16
|
||||
|
||||
encrypted = AES.new(channel_secret, AES.MODE_ECB).encrypt(plaintext)
|
||||
|
||||
# cipher_mac = first 2 bytes of HMAC-SHA256(channel_secret, encrypted)
|
||||
h = HMAC.new(channel_secret, digestmod=SHA256)
|
||||
h.update(encrypted)
|
||||
cipher_mac = h.digest()[:2]
|
||||
|
||||
# pkt_payload layout: 1-byte chan_hash + 2-byte cipher_mac + ciphertext
|
||||
pkt_payload = bytes([channel_hash_byte]) + cipher_mac + encrypted
|
||||
|
||||
# parsePacketPayload expects the full payload buffer:
|
||||
# header byte (route_type=1 DIRECT, payload_type=5 channel, ver=0)
|
||||
# path_byte (path_len=0, path_hash_size=1) → 0x00
|
||||
# pkt_payload
|
||||
header = 0x15 # route_type=1, payload_type=5, payload_ver=0
|
||||
path_byte = 0x00
|
||||
payload = bytes([header, path_byte]) + pkt_payload
|
||||
|
||||
log_data = await parser.parsePacketPayload(payload, log_data={})
|
||||
|
||||
assert log_data["payload_type"] == 0x05
|
||||
assert "txt_type" in log_data, (
|
||||
f"txt_type missing from log_data — channel decrypt path was not reached. "
|
||||
f"log_data keys: {list(log_data.keys())}"
|
||||
)
|
||||
assert log_data["txt_type"] == 5, (
|
||||
f"Expected txt_type=5, got {log_data['txt_type']}"
|
||||
)
|
||||
assert log_data["attempt"] == 1, (
|
||||
f"Expected attempt=1, got {log_data['attempt']}"
|
||||
)
|
||||
238
tests/unit/test_transport_symmetry.py
Normal file
238
tests/unit/test_transport_symmetry.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Verification tests for transport symmetry fixes.
|
||||
|
||||
Covers: send symmetry across transports, serial disconnect callback on
|
||||
transport-lost, serial connect timeout, oversize-frame return, BLE
|
||||
disconnect-callback re-registration, BLE pairing failure re-raise,
|
||||
TCP counter per frame not per segment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from meshcore.tcp_cx import TCPConnection
|
||||
from meshcore.serial_cx import SerialConnection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RecordingReader:
|
||||
"""Minimal reader mock that records dispatched frames."""
|
||||
def __init__(self):
|
||||
self.frames = []
|
||||
|
||||
async def handle_rx(self, data):
|
||||
self.frames.append(bytes(data))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TCP send() wraps transport.write in try/except
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_send_write_error_fires_disconnect():
|
||||
"""TCP: OSError during transport.write fires _disconnect_callback."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
cb = AsyncMock()
|
||||
cx.set_disconnect_callback(cb)
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.write.side_effect = OSError("Broken pipe")
|
||||
cx.transport = mock_transport
|
||||
cx._send_count = 0
|
||||
cx._receive_count = 0
|
||||
|
||||
await cx.send(b"\x01\x02\x03")
|
||||
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert "tcp_write_failed" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serial send() fires disconnect on transport-lost and write error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_send_no_transport_fires_disconnect():
|
||||
"""Serial: send() on None transport fires _disconnect_callback ."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
cb = AsyncMock()
|
||||
cx.set_disconnect_callback(cb)
|
||||
cx.transport = None
|
||||
|
||||
await cx.send(b"\x01")
|
||||
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert reason == "serial_transport_lost"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_send_write_error_fires_disconnect():
|
||||
"""Serial: OSError during transport.write fires _disconnect_callback."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
cb = AsyncMock()
|
||||
cx.set_disconnect_callback(cb)
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.write.side_effect = OSError("Device not configured")
|
||||
cx.transport = mock_transport
|
||||
|
||||
await cx.send(b"\x01")
|
||||
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert "serial_write_failed" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BLE send() fires disconnect on transport-lost and write error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ble_send_no_client_fires_disconnect():
|
||||
"""BLE: send() with no client fires _disconnect_callback."""
|
||||
# Can't import BLEConnection directly if bleak isn't installed,
|
||||
# so test via dynamic import with a guard.
|
||||
try:
|
||||
from meshcore.ble_cx import BLEConnection
|
||||
except ImportError:
|
||||
pytest.skip("bleak not installed")
|
||||
|
||||
# BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it
|
||||
with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \
|
||||
patch("meshcore.ble_cx.BleakClient", MagicMock()):
|
||||
cx = BLEConnection.__new__(BLEConnection)
|
||||
cx.client = None
|
||||
cx._user_provided_client = None
|
||||
cx._user_provided_address = None
|
||||
cx._user_provided_device = None
|
||||
cx.address = None
|
||||
cx.device = None
|
||||
cx.pin = None
|
||||
cx.rx_char = None
|
||||
cb = AsyncMock()
|
||||
cx._disconnect_callback = cb
|
||||
|
||||
result = await cx.send(b"\x01")
|
||||
|
||||
assert result is False
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert reason == "ble_transport_lost"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serial connect() times out if connection_made never fires
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_connect_timeout():
|
||||
"""Serial: connect() raises TimeoutError if connection_made never fires."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
|
||||
# Mock create_serial_connection to do nothing (never fires connection_made)
|
||||
async def mock_create(*args, **kwargs):
|
||||
return (MagicMock(), MagicMock())
|
||||
|
||||
with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection",
|
||||
side_effect=mock_create):
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await cx.connect(timeout=0.1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Oversize frame resets state and returns without dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_oversize_frame_empty_data_returns():
|
||||
"""TCP: oversize header with no trailing data returns without dispatch."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
|
||||
# Build a frame header with size > 300 and no payload data after header
|
||||
# Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4)
|
||||
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||
cx.handle_rx(header)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# No frames should be dispatched, and state should be reset
|
||||
assert reader.frames == []
|
||||
assert cx.header == b""
|
||||
assert cx.inframe == b""
|
||||
assert cx.frame_expected_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_oversize_frame_empty_data_returns():
|
||||
"""Serial: oversize header with no trailing data returns without dispatch."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
|
||||
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||
cx.handle_rx(header)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert reader.frames == []
|
||||
assert cx.header == b""
|
||||
assert cx.inframe == b""
|
||||
assert cx.frame_expected_size == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TCP receive counter increments per MeshCore frame, not per TCP segment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_receive_count_per_frame_not_per_segment():
|
||||
"""TCP: _receive_count increments per completed frame, not per data_received call."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
cx._receive_count = 0
|
||||
|
||||
# Build a 4-byte payload frame
|
||||
payload = b"\xAA\xBB\xCC\xDD"
|
||||
frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload
|
||||
|
||||
# Split the frame into 3 TCP segments (simulating fragmentation)
|
||||
protocol = TCPConnection.MCClientProtocol(cx)
|
||||
protocol.data_received(frame[:2]) # partial header
|
||||
protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload
|
||||
protocol.data_received(frame[5:]) # remaining payload
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 3 data_received calls but only 1 completed frame
|
||||
assert cx._receive_count == 1
|
||||
assert reader.frames == [payload]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_multiple_frames_count_correctly():
|
||||
"""TCP: two complete frames in separate segments → _receive_count == 2."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
cx._receive_count = 0
|
||||
|
||||
payload1 = b"\x01\x02"
|
||||
frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1
|
||||
payload2 = b"\x03\x04\x05"
|
||||
frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2
|
||||
|
||||
protocol = TCPConnection.MCClientProtocol(cx)
|
||||
protocol.data_received(frame1)
|
||||
protocol.data_received(frame2)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert cx._receive_count == 2
|
||||
assert reader.frames == [payload1, payload2]
|
||||
Reference in New Issue
Block a user