mirror of
https://github.com/meshcore-dev/meshcore_py.git
synced 2026-06-11 11:56:18 +00:00
Merge branch 'main' into fix/reader-parser-crash-safety
This commit is contained in:
@@ -51,6 +51,14 @@ class BLEConnection:
|
|||||||
self.pin = pin
|
self.pin = pin
|
||||||
self.rx_char = None
|
self.rx_char = None
|
||||||
self._disconnect_callback = None
|
self._disconnect_callback = None
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
def _spawn_background(self, coro) -> asyncio.Task:
|
||||||
|
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||||
|
task = asyncio.create_task(coro)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""
|
"""
|
||||||
@@ -116,9 +124,12 @@ class BLEConnection:
|
|||||||
await self.client.pair()
|
await self.client.pair()
|
||||||
logger.info("BLE pairing successful")
|
logger.info("BLE pairing successful")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"BLE pairing failed: {e}")
|
logger.error(f"BLE pairing failed: {e}")
|
||||||
# Don't fail the connection if pairing fails, as the device
|
# A failed pairing leaves the transport in a half-usable
|
||||||
# might already be paired or not require pairing
|
# state — re-raise so the caller gets a clean failure
|
||||||
|
# instead of a silently degraded connection.
|
||||||
|
await self.client.disconnect()
|
||||||
|
raise
|
||||||
|
|
||||||
except BleakDeviceNotFoundError:
|
except BleakDeviceNotFoundError:
|
||||||
return None
|
return None
|
||||||
@@ -154,8 +165,19 @@ class BLEConnection:
|
|||||||
self.client = self._user_provided_client
|
self.client = self._user_provided_client
|
||||||
self.device = self._user_provided_device
|
self.device = self._user_provided_device
|
||||||
|
|
||||||
|
# Re-register disconnect callback on the reset client so subsequent
|
||||||
|
# disconnects after a reconnect cycle are still detected.
|
||||||
|
if self.client is not None and hasattr(self.client, 'set_disconnected_callback'):
|
||||||
|
try:
|
||||||
|
self.client.set_disconnected_callback(self.handle_disconnect)
|
||||||
|
except Exception:
|
||||||
|
# set_disconnected_callback may not be available on all bleak
|
||||||
|
# versions; the next connect() call will re-create the client
|
||||||
|
# with the callback anyway.
|
||||||
|
pass
|
||||||
|
|
||||||
if self._disconnect_callback:
|
if self._disconnect_callback:
|
||||||
asyncio.create_task(self._disconnect_callback("ble_disconnect"))
|
self._spawn_background(self._disconnect_callback("ble_disconnect"))
|
||||||
|
|
||||||
def set_disconnect_callback(self, callback):
|
def set_disconnect_callback(self, callback):
|
||||||
"""Set callback to handle disconnections."""
|
"""Set callback to handle disconnections."""
|
||||||
@@ -166,16 +188,24 @@ class BLEConnection:
|
|||||||
|
|
||||||
def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray):
|
def handle_rx(self, _: BleakGATTCharacteristic, data: bytearray):
|
||||||
if self.reader is not None:
|
if self.reader is not None:
|
||||||
asyncio.create_task(self.reader.handle_rx(data))
|
self._spawn_background(self.reader.handle_rx(data))
|
||||||
|
|
||||||
async def send(self, data):
|
async def send(self, data):
|
||||||
if not self.client:
|
if not self.client:
|
||||||
logger.error("Client is not connected")
|
logger.error("Client is not connected")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback("ble_transport_lost")
|
||||||
return False
|
return False
|
||||||
if not self.rx_char:
|
if not self.rx_char:
|
||||||
logger.error("RX characteristic not found")
|
logger.error("RX characteristic not found")
|
||||||
return False
|
return False
|
||||||
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
|
try:
|
||||||
|
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"BLE write failed: {exc}")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback(f"ble_write_failed: {exc}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Disconnect from the BLE device."""
|
"""Disconnect from the BLE device."""
|
||||||
|
|||||||
@@ -58,17 +58,32 @@ def _validate_destination(dst: DestinationType, prefix_length: int = 6) -> bytes
|
|||||||
|
|
||||||
|
|
||||||
class CommandHandlerBase:
|
class CommandHandlerBase:
|
||||||
|
"""Base class for command handlers.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The internal ``asyncio.Lock`` is created lazily on first access
|
||||||
|
so that it binds to the correct running event loop (required for
|
||||||
|
Python 3.9/3.10 compatibility).
|
||||||
|
"""
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 5.0
|
DEFAULT_TIMEOUT = 5.0
|
||||||
|
|
||||||
def __init__(self, default_timeout: Optional[float] = None):
|
def __init__(self, default_timeout: Optional[float] = None):
|
||||||
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
|
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
|
||||||
self._reader: Optional[MessageReader] = None
|
self._reader: Optional[MessageReader] = None
|
||||||
self.dispatcher: Optional[EventDispatcher] = None
|
self.dispatcher: Optional[EventDispatcher] = None
|
||||||
self._mesh_request_lock = asyncio.Lock()
|
self.__mesh_request_lock: Optional[asyncio.Lock] = None
|
||||||
self.default_timeout = (
|
self.default_timeout = (
|
||||||
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
|
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _mesh_request_lock(self) -> asyncio.Lock:
|
||||||
|
"""Lazy-init lock so it binds to the running loop, not import-time."""
|
||||||
|
if self.__mesh_request_lock is None:
|
||||||
|
self.__mesh_request_lock = asyncio.Lock()
|
||||||
|
return self.__mesh_request_lock
|
||||||
|
|
||||||
def set_connection(self, connection: Any) -> None:
|
def set_connection(self, connection: Any) -> None:
|
||||||
async def sender(data: bytes) -> None:
|
async def sender(data: bytes) -> None:
|
||||||
await connection.send(data)
|
await connection.send(data)
|
||||||
@@ -90,6 +105,14 @@ class CommandHandlerBase:
|
|||||||
expected_events: Optional[Union[EventType, List[EventType]]] = None,
|
expected_events: Optional[Union[EventType, List[EventType]]] = None,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
) -> Event:
|
) -> Event:
|
||||||
|
"""Wait for the first of *expected_events* to arrive.
|
||||||
|
|
||||||
|
Returns the first matched ``Event``. When ``EventType.ERROR`` is
|
||||||
|
among the expected types, the caller **must** check
|
||||||
|
``result.is_error()`` before accessing command-specific payload
|
||||||
|
keys — an ERROR payload is ``{"reason": "..."}`` and will
|
||||||
|
``KeyError`` on any other key.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Convert single event to list if needed
|
# Convert single event to list if needed
|
||||||
if not isinstance(expected_events, list):
|
if not isinstance(expected_events, list):
|
||||||
@@ -129,9 +152,6 @@ class CommandHandlerBase:
|
|||||||
logger.debug(f"Command error: {e}")
|
logger.debug(f"Command error: {e}")
|
||||||
return Event(EventType.ERROR, {"error": str(e)})
|
return Event(EventType.ERROR, {"error": str(e)})
|
||||||
|
|
||||||
return Event(EventType.ERROR, {})
|
|
||||||
|
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
data: bytes,
|
data: bytes,
|
||||||
@@ -151,7 +171,14 @@ class CommandHandlerBase:
|
|||||||
timeout: Timeout in seconds, or None to use default_timeout
|
timeout: Timeout in seconds, or None to use default_timeout
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Event: The full event object that was received in response to the command
|
Event: The full event object that was received in response to
|
||||||
|
the command.
|
||||||
|
|
||||||
|
Important:
|
||||||
|
When ``EventType.ERROR`` is included in *expected_events*, the
|
||||||
|
returned event may be an error response. Callers **must**
|
||||||
|
check ``result.is_error()`` before accessing command-specific
|
||||||
|
payload keys to avoid ``KeyError``.
|
||||||
"""
|
"""
|
||||||
if not self.dispatcher:
|
if not self.dispatcher:
|
||||||
raise RuntimeError("Dispatcher not set, cannot send commands")
|
raise RuntimeError("Dispatcher not set, cannot send commands")
|
||||||
@@ -170,7 +197,7 @@ class CommandHandlerBase:
|
|||||||
futures: List[asyncio.Future] = []
|
futures: List[asyncio.Future] = []
|
||||||
subscriptions = []
|
subscriptions = []
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_running_loop()
|
||||||
for event_type in expected_events:
|
for event_type in expected_events:
|
||||||
future = loop.create_future()
|
future = loop.create_future()
|
||||||
|
|
||||||
@@ -266,6 +293,7 @@ class CommandHandlerBase:
|
|||||||
contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path
|
contact = self._get_contact_by_prefix(dst_bytes.hex()) # need a contact for return path
|
||||||
if contact is None:
|
if contact is None:
|
||||||
logger.error("No contact found")
|
logger.error("No contact found")
|
||||||
|
return Event(EventType.ERROR, {"reason": "contact_not_found"})
|
||||||
|
|
||||||
zero_hop = False
|
zero_hop = False
|
||||||
if contact["out_path_len"] == -1:
|
if contact["out_path_len"] == -1:
|
||||||
|
|||||||
@@ -185,6 +185,24 @@ class ContactCommands(CommandHandlerBase):
|
|||||||
data = b"\x3B"
|
data = b"\x3B"
|
||||||
return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR])
|
return await self.send(data, [EventType.AUTOADD_CONFIG, EventType.ERROR])
|
||||||
|
|
||||||
|
async def get_contact_by_key(self, pubkey: bytes) -> Event:
|
||||||
|
"""N09: Retrieve a single contact by its public key (CMD 30).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pubkey: 32-byte public key of the contact.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Event with the contact data (same format as CONTACT/NEXT_CONTACT),
|
||||||
|
or ERROR if not found.
|
||||||
|
"""
|
||||||
|
if not isinstance(pubkey, (bytes, bytearray)):
|
||||||
|
raise TypeError("pubkey must be bytes-like")
|
||||||
|
# Truncate or pad to 32 bytes
|
||||||
|
key_bytes = bytes(pubkey[:32])
|
||||||
|
logger.debug(f"Getting contact by key: {key_bytes.hex()}")
|
||||||
|
data = b"\x1e" + key_bytes
|
||||||
|
return await self.send(data, [EventType.NEXT_CONTACT, EventType.ERROR])
|
||||||
|
|
||||||
async def get_advert_path(self, key: DestinationType) -> Event:
|
async def get_advert_path(self, key: DestinationType) -> Event:
|
||||||
key_bytes = _validate_destination(key, prefix_length=32)
|
key_bytes = _validate_destination(key, prefix_length=32)
|
||||||
logger.debug(f"getting advert path for: {key} {key_bytes.hex()}")
|
logger.debug(f"getting advert path for: {key} {key_bytes.hex()}")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from hashlib import sha256
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..events import Event, EventType
|
from ..events import Event, EventType
|
||||||
|
from ..packets import CommandType
|
||||||
from .base import CommandHandlerBase, DestinationType, _validate_destination
|
from .base import CommandHandlerBase, DestinationType, _validate_destination
|
||||||
|
|
||||||
logger = logging.getLogger("meshcore")
|
logger = logging.getLogger("meshcore")
|
||||||
@@ -13,7 +14,7 @@ class DeviceCommands(CommandHandlerBase):
|
|||||||
async def send_appstart(self) -> Event:
|
async def send_appstart(self) -> Event:
|
||||||
logger.debug("Sending appstart command")
|
logger.debug("Sending appstart command")
|
||||||
b1 = bytearray(b"\x01\x03 mccli")
|
b1 = bytearray(b"\x01\x03 mccli")
|
||||||
return await self.send(b1, [EventType.SELF_INFO])
|
return await self.send(b1, [EventType.SELF_INFO, EventType.ERROR])
|
||||||
|
|
||||||
async def send_device_query(self) -> Event:
|
async def send_device_query(self) -> Event:
|
||||||
logger.debug("Sending device query command")
|
logger.debug("Sending device query command")
|
||||||
@@ -129,32 +130,50 @@ class DeviceCommands(CommandHandlerBase):
|
|||||||
return await self.send(data, [EventType.OK, EventType.ERROR])
|
return await self.send(data, [EventType.OK, EventType.ERROR])
|
||||||
|
|
||||||
async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event:
|
async def set_telemetry_mode_base(self, telemetry_mode_base: int) -> Event:
|
||||||
infos = (await self.send_appstart()).payload
|
result = await self.send_appstart()
|
||||||
|
if result.is_error():
|
||||||
|
return result
|
||||||
|
infos = result.payload
|
||||||
infos["telemetry_mode_base"] = telemetry_mode_base
|
infos["telemetry_mode_base"] = telemetry_mode_base
|
||||||
return await self.set_other_params_from_infos(infos)
|
return await self.set_other_params_from_infos(infos)
|
||||||
|
|
||||||
async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event:
|
async def set_telemetry_mode_loc(self, telemetry_mode_loc: int) -> Event:
|
||||||
infos = (await self.send_appstart()).payload
|
result = await self.send_appstart()
|
||||||
|
if result.is_error():
|
||||||
|
return result
|
||||||
|
infos = result.payload
|
||||||
infos["telemetry_mode_loc"] = telemetry_mode_loc
|
infos["telemetry_mode_loc"] = telemetry_mode_loc
|
||||||
return await self.set_other_params_from_infos(infos)
|
return await self.set_other_params_from_infos(infos)
|
||||||
|
|
||||||
async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event:
|
async def set_telemetry_mode_env(self, telemetry_mode_env: int) -> Event:
|
||||||
infos = (await self.send_appstart()).payload
|
result = await self.send_appstart()
|
||||||
|
if result.is_error():
|
||||||
|
return result
|
||||||
|
infos = result.payload
|
||||||
infos["telemetry_mode_env"] = telemetry_mode_env
|
infos["telemetry_mode_env"] = telemetry_mode_env
|
||||||
return await self.set_other_params_from_infos(infos)
|
return await self.set_other_params_from_infos(infos)
|
||||||
|
|
||||||
async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event:
|
async def set_manual_add_contacts(self, manual_add_contacts: bool) -> Event:
|
||||||
infos = (await self.send_appstart()).payload
|
result = await self.send_appstart()
|
||||||
|
if result.is_error():
|
||||||
|
return result
|
||||||
|
infos = result.payload
|
||||||
infos["manual_add_contacts"] = manual_add_contacts
|
infos["manual_add_contacts"] = manual_add_contacts
|
||||||
return await self.set_other_params_from_infos(infos)
|
return await self.set_other_params_from_infos(infos)
|
||||||
|
|
||||||
async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event:
|
async def set_advert_loc_policy(self, advert_loc_policy: int) -> Event:
|
||||||
infos = (await self.send_appstart()).payload
|
result = await self.send_appstart()
|
||||||
|
if result.is_error():
|
||||||
|
return result
|
||||||
|
infos = result.payload
|
||||||
infos["adv_loc_policy"] = advert_loc_policy
|
infos["adv_loc_policy"] = advert_loc_policy
|
||||||
return await self.set_other_params_from_infos(infos)
|
return await self.set_other_params_from_infos(infos)
|
||||||
|
|
||||||
async def set_multi_acks(self, multi_acks: int) -> Event:
|
async def set_multi_acks(self, multi_acks: int) -> Event:
|
||||||
infos = (await self.send_appstart()).payload
|
result = await self.send_appstart()
|
||||||
|
if result.is_error():
|
||||||
|
return result
|
||||||
|
infos = result.payload
|
||||||
infos["multi_acks"] = multi_acks
|
infos["multi_acks"] = multi_acks
|
||||||
return await self.set_other_params_from_infos(infos)
|
return await self.set_other_params_from_infos(infos)
|
||||||
|
|
||||||
@@ -273,20 +292,89 @@ class DeviceCommands(CommandHandlerBase):
|
|||||||
|
|
||||||
return await self.sign_finish(timeout=timeout, data_size=len(data))
|
return await self.sign_finish(timeout=timeout, data_size=len(data))
|
||||||
|
|
||||||
|
async def has_connection(self) -> Event:
|
||||||
|
"""N09: Check if the device has an active connection (CMD 28).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Event with a 1-byte response indicating connection status,
|
||||||
|
or ERROR.
|
||||||
|
"""
|
||||||
|
logger.debug("Checking device connection status")
|
||||||
|
return await self.send(b"\x1c", [EventType.OK, EventType.ERROR])
|
||||||
|
|
||||||
|
async def get_tuning(self) -> Event:
|
||||||
|
"""N03/N09: Request current tuning parameters (CMD_GET_TUNING_PARAMS = 43).
|
||||||
|
|
||||||
|
Firmware responds with RESP_CODE_TUNING_PARAMS (23): 9 bytes containing
|
||||||
|
rx_delay (4 bytes LE) and airtime_factor (4 bytes LE).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Event of type TUNING_PARAMS with rx_delay and airtime_factor,
|
||||||
|
or ERROR.
|
||||||
|
"""
|
||||||
|
logger.debug("Getting tuning parameters")
|
||||||
|
return await self.send(b"\x2b", [EventType.TUNING_PARAMS, EventType.ERROR])
|
||||||
|
|
||||||
|
async def request_factory_reset(self) -> str:
|
||||||
|
"""N09: Request a factory reset token (step 1 of 2).
|
||||||
|
|
||||||
|
This method returns a confirmation token string. Pass it to
|
||||||
|
``confirm_factory_reset(token)`` to actually execute the reset.
|
||||||
|
The two-step pattern is a Python-side safety measure; the firmware
|
||||||
|
itself has no token verification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A confirmation token string to pass to confirm_factory_reset().
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
token = secrets.token_hex(8)
|
||||||
|
logger.warning(
|
||||||
|
"Factory reset requested. Call confirm_factory_reset('%s') to proceed. "
|
||||||
|
"This will ERASE ALL DATA on the device.", token
|
||||||
|
)
|
||||||
|
# Store the token on the instance for validation
|
||||||
|
self._factory_reset_token = token
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def confirm_factory_reset(self, token: str) -> Event:
|
||||||
|
"""N09: Execute factory reset after token confirmation (step 2 of 2).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The token returned by request_factory_reset().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Event with OK or ERROR.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the token does not match.
|
||||||
|
"""
|
||||||
|
expected = getattr(self, "_factory_reset_token", None)
|
||||||
|
if expected is None or token != expected:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid or expired factory reset token. "
|
||||||
|
"Call request_factory_reset() first."
|
||||||
|
)
|
||||||
|
self._factory_reset_token = None # Consume the token
|
||||||
|
logger.warning("Executing factory reset — all device data will be erased")
|
||||||
|
return await self.send(b"\x33", [EventType.OK, EventType.ERROR])
|
||||||
|
|
||||||
async def get_stats_core(self) -> Event:
|
async def get_stats_core(self) -> Event:
|
||||||
logger.debug("Getting core statistics")
|
logger.debug("Getting core statistics")
|
||||||
# CMD_GET_STATS (56) + STATS_TYPE_CORE (0)
|
# R04: Use CommandType enum instead of literal bytes
|
||||||
return await self.send(b"\x38\x00", [EventType.STATS_CORE, EventType.ERROR])
|
cmd = bytes([CommandType.GET_STATS.value, 0x00]) # GET_STATS + STATS_TYPE_CORE
|
||||||
|
return await self.send(cmd, [EventType.STATS_CORE, EventType.ERROR])
|
||||||
|
|
||||||
async def get_stats_radio(self) -> Event:
|
async def get_stats_radio(self) -> Event:
|
||||||
logger.debug("Getting radio statistics")
|
logger.debug("Getting radio statistics")
|
||||||
# CMD_GET_STATS (56) + STATS_TYPE_RADIO (1)
|
# R04: Use CommandType enum instead of literal bytes
|
||||||
return await self.send(b"\x38\x01", [EventType.STATS_RADIO, EventType.ERROR])
|
cmd = bytes([CommandType.GET_STATS.value, 0x01]) # GET_STATS + STATS_TYPE_RADIO
|
||||||
|
return await self.send(cmd, [EventType.STATS_RADIO, EventType.ERROR])
|
||||||
|
|
||||||
async def get_stats_packets(self) -> Event:
|
async def get_stats_packets(self) -> Event:
|
||||||
logger.debug("Getting packet statistics")
|
logger.debug("Getting packet statistics")
|
||||||
# CMD_GET_STATS (56) + STATS_TYPE_PACKETS (2)
|
# R04: Use CommandType enum instead of literal bytes
|
||||||
return await self.send(b"\x38\x02", [EventType.STATS_PACKETS, EventType.ERROR])
|
cmd = bytes([CommandType.GET_STATS.value, 0x02]) # GET_STATS + STATS_TYPE_PACKETS
|
||||||
|
return await self.send(cmd, [EventType.STATS_PACKETS, EventType.ERROR])
|
||||||
|
|
||||||
async def get_allowed_repeat_freq(self) -> Event:
|
async def get_allowed_repeat_freq(self) -> Event:
|
||||||
logger.debug("Getting allowed repeat freqs")
|
logger.debug("Getting allowed repeat freqs")
|
||||||
|
|||||||
@@ -144,8 +144,12 @@ class MessagingCommands(CommandHandlerBase):
|
|||||||
logger.info(f"Retry sending msg: {attempts + 1}")
|
logger.info(f"Retry sending msg: {attempts + 1}")
|
||||||
|
|
||||||
result = await self.send_msg(dst, msg, timestamp, attempt=attempts)
|
result = await self.send_msg(dst, msg, timestamp, attempt=attempts)
|
||||||
if result.type == EventType.ERROR:
|
if result.is_error():
|
||||||
logger.error(f"⚠️ Failed to send message: {result.payload}")
|
logger.error(f"Failed to send message: {result.payload}")
|
||||||
|
attempts += 1
|
||||||
|
if flood:
|
||||||
|
flood_attempts += 1
|
||||||
|
continue
|
||||||
|
|
||||||
exp_ack = result.payload["expected_ack"].hex()
|
exp_ack = result.payload["expected_ack"].hex()
|
||||||
timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout
|
timeout = result.payload["suggested_timeout"] / 1000 * 1.2 if timeout==0 else timeout
|
||||||
@@ -255,7 +259,7 @@ class MessagingCommands(CommandHandlerBase):
|
|||||||
elif path_hash_len == 8 :
|
elif path_hash_len == 8 :
|
||||||
flags = 3
|
flags = 3
|
||||||
else :
|
else :
|
||||||
logger.error(f"Invalid path format: {e}")
|
logger.error(f"Invalid path format: unknown path_hash_len {path_hash_len}")
|
||||||
return Event(EventType.ERROR, {"reason": "invalid_path_format"})
|
return Event(EventType.ERROR, {"reason": "invalid_path_format"})
|
||||||
else:
|
else:
|
||||||
flags = 0
|
flags = 0
|
||||||
@@ -291,12 +295,34 @@ class MessagingCommands(CommandHandlerBase):
|
|||||||
cmd_data.append(flags)
|
cmd_data.append(flags)
|
||||||
cmd_data.extend(path_bytes)
|
cmd_data.extend(path_bytes)
|
||||||
|
|
||||||
|
# N05: Firmware requires strict len > 10 (MyMesh.cpp:1620).
|
||||||
|
# When path is empty, cmd(1)+tag(4)+auth(4)+flags(1) = 10 bytes exactly,
|
||||||
|
# which is silently rejected. Pad with one zero byte to reach 11.
|
||||||
|
if len(cmd_data) <= 10:
|
||||||
|
cmd_data.append(0x00)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}"
|
f"Sending trace: tag={tag}, auth={auth_code}, flags={flags}, path={path_bytes.hex()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR])
|
return await self.send(cmd_data, [EventType.MSG_SENT, EventType.ERROR])
|
||||||
|
|
||||||
|
async def send_raw_data(self, payload: bytes) -> Event:
|
||||||
|
"""N09: Send raw data via CMD_SEND_RAW_DATA (25).
|
||||||
|
|
||||||
|
Sends an arbitrary payload through the mesh network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: Raw bytes to send.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Event with MSG_SENT or ERROR.
|
||||||
|
"""
|
||||||
|
if not isinstance(payload, (bytes, bytearray)):
|
||||||
|
raise TypeError("payload must be bytes-like")
|
||||||
|
data = b"\x19" + bytes(payload)
|
||||||
|
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
|
||||||
|
|
||||||
async def set_flood_scope(self, scope):
|
async def set_flood_scope(self, scope):
|
||||||
if scope is None:
|
if scope is None:
|
||||||
logger.debug(f"Resetting scope")
|
logger.debug(f"Resetting scope")
|
||||||
|
|||||||
@@ -4,14 +4,23 @@ Connection manager that orchestrates reconnection logic for any connection type.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Any, Callable, Protocol
|
from typing import Optional, Any, Awaitable, Callable, Protocol
|
||||||
from .events import Event, EventType
|
from .events import Event, EventType
|
||||||
|
|
||||||
logger = logging.getLogger("meshcore")
|
logger = logging.getLogger("meshcore")
|
||||||
|
|
||||||
|
|
||||||
class ConnectionProtocol(Protocol):
|
class ConnectionProtocol(Protocol):
|
||||||
"""Protocol defining the interface that connection classes must implement."""
|
"""Protocol defining the interface that connection classes must implement.
|
||||||
|
|
||||||
|
Return contract for connect():
|
||||||
|
- On success: return a truthy value (typically an address string)
|
||||||
|
that identifies the connection. This value is included in the
|
||||||
|
CONNECTED event payload as ``connection_info``.
|
||||||
|
- On failure: return ``None`` (soft failure — triggers a retry in
|
||||||
|
``_attempt_reconnect``) **or** raise an exception (hard failure —
|
||||||
|
also triggers a retry, logged as an error).
|
||||||
|
"""
|
||||||
|
|
||||||
async def connect(self) -> Optional[Any]:
|
async def connect(self) -> Optional[Any]:
|
||||||
"""Connect and return connection info, or None if failed."""
|
"""Connect and return connection info, or None if failed."""
|
||||||
@@ -39,11 +48,13 @@ class ConnectionManager:
|
|||||||
event_dispatcher=None,
|
event_dispatcher=None,
|
||||||
auto_reconnect: bool = False,
|
auto_reconnect: bool = False,
|
||||||
max_reconnect_attempts: int = 3,
|
max_reconnect_attempts: int = 3,
|
||||||
|
reconnect_callback: Optional[Callable[[], Awaitable[None]]] = None,
|
||||||
):
|
):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.event_dispatcher = event_dispatcher
|
self.event_dispatcher = event_dispatcher
|
||||||
self.auto_reconnect = auto_reconnect
|
self.auto_reconnect = auto_reconnect
|
||||||
self.max_reconnect_attempts = max_reconnect_attempts
|
self.max_reconnect_attempts = max_reconnect_attempts
|
||||||
|
self._reconnect_callback = reconnect_callback
|
||||||
|
|
||||||
self._reconnect_attempts = 0
|
self._reconnect_attempts = 0
|
||||||
self._is_connected = False
|
self._is_connected = False
|
||||||
@@ -109,45 +120,51 @@ class ConnectionManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _attempt_reconnect(self):
|
async def _attempt_reconnect(self):
|
||||||
"""Attempt to reconnect with flat delay."""
|
"""Attempt to reconnect using an iterative loop.
|
||||||
logger.debug(
|
|
||||||
f"Attempting reconnection ({self._reconnect_attempts + 1}/{self.max_reconnect_attempts})"
|
|
||||||
)
|
|
||||||
self._reconnect_attempts += 1
|
|
||||||
|
|
||||||
# Flat 1 second delay for all attempts
|
Runs as a single persistent task for the entire reconnect session.
|
||||||
await asyncio.sleep(1)
|
Previous implementation used tail-recursion via create_task which
|
||||||
|
orphaned the running task reference — disconnect() could only cancel
|
||||||
|
the newest pointer, leaving earlier attempts in flight (F03).
|
||||||
|
"""
|
||||||
|
while self._reconnect_attempts < self.max_reconnect_attempts:
|
||||||
|
self._reconnect_attempts += 1
|
||||||
|
logger.debug(
|
||||||
|
f"Attempting reconnection ({self._reconnect_attempts}/{self.max_reconnect_attempts})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flat 1 second delay for all attempts
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self.connection.connect()
|
||||||
|
if result is not None:
|
||||||
|
self._is_connected = True
|
||||||
|
self._reconnect_attempts = 0
|
||||||
|
|
||||||
|
# Invoke reconnect callback (e.g. send_appstart) if provided
|
||||||
|
if self._reconnect_callback is not None:
|
||||||
|
try:
|
||||||
|
await self._reconnect_callback()
|
||||||
|
except Exception as cb_err:
|
||||||
|
logger.warning(
|
||||||
|
f"Reconnect callback failed: {cb_err}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self.connection.connect()
|
|
||||||
if result is not None:
|
|
||||||
self._is_connected = True
|
|
||||||
self._reconnect_attempts = 0
|
|
||||||
await self._emit_event(
|
|
||||||
EventType.CONNECTED,
|
|
||||||
{"connection_info": result, "reconnected": True},
|
|
||||||
)
|
|
||||||
logger.debug("Reconnected successfully")
|
|
||||||
else:
|
|
||||||
# Reconnection failed, try again if we haven't exceeded max attempts
|
|
||||||
if self._reconnect_attempts < self.max_reconnect_attempts:
|
|
||||||
self._reconnect_task = asyncio.create_task(
|
|
||||||
self._attempt_reconnect()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await self._emit_event(
|
await self._emit_event(
|
||||||
EventType.DISCONNECTED,
|
EventType.CONNECTED,
|
||||||
{"reason": "reconnect_failed", "max_attempts_exceeded": True},
|
{"connection_info": result, "reconnected": True},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
logger.debug("Reconnected successfully")
|
||||||
logger.debug(f"Reconnection attempt failed: {e}")
|
return
|
||||||
if self._reconnect_attempts < self.max_reconnect_attempts:
|
except Exception as e:
|
||||||
self._reconnect_task = asyncio.create_task(self._attempt_reconnect())
|
logger.debug(f"Reconnection attempt failed: {e}")
|
||||||
else:
|
|
||||||
await self._emit_event(
|
# All attempts exhausted
|
||||||
EventType.DISCONNECTED,
|
await self._emit_event(
|
||||||
{"reason": f"reconnect_error: {e}", "max_attempts_exceeded": True},
|
EventType.DISCONNECTED,
|
||||||
)
|
{"reason": "reconnect_failed", "max_attempts_exceeded": True},
|
||||||
|
)
|
||||||
|
|
||||||
async def _emit_event(self, event_type: EventType, payload: dict):
|
async def _emit_event(self, event_type: EventType, payload: dict):
|
||||||
"""Emit connection events if dispatcher is available."""
|
"""Emit connection events if dispatcher is available."""
|
||||||
|
|||||||
@@ -49,6 +49,9 @@ class EventType(Enum):
|
|||||||
PATH_RESPONSE = "path_response"
|
PATH_RESPONSE = "path_response"
|
||||||
PRIVATE_KEY = "private_key"
|
PRIVATE_KEY = "private_key"
|
||||||
DISABLED = "disabled"
|
DISABLED = "disabled"
|
||||||
|
CONTACT_DELETED = "contact_deleted"
|
||||||
|
CONTACTS_FULL = "contacts_full"
|
||||||
|
TUNING_PARAMS = "tuning_params"
|
||||||
CONTROL_DATA = "control_data"
|
CONTROL_DATA = "control_data"
|
||||||
DISCOVER_RESPONSE = "discover_response"
|
DISCOVER_RESPONSE = "discover_response"
|
||||||
NEIGHBOURS_RESPONSE = "neighbours_response"
|
NEIGHBOURS_RESPONSE = "neighbours_response"
|
||||||
@@ -104,6 +107,17 @@ class Event:
|
|||||||
if kwargs:
|
if kwargs:
|
||||||
self.attributes.update(kwargs)
|
self.attributes.update(kwargs)
|
||||||
|
|
||||||
|
def is_error(self) -> bool:
|
||||||
|
"""Return True if this event represents an error response.
|
||||||
|
|
||||||
|
Callers that include ``EventType.ERROR`` in their expected-events
|
||||||
|
list **must** check ``result.is_error()`` (or ``result.type ==
|
||||||
|
EventType.ERROR``) before accessing keyed payload fields, because
|
||||||
|
an ERROR payload contains ``{"reason": "..."}`` — not the
|
||||||
|
command-specific keys the caller expects on the happy path.
|
||||||
|
"""
|
||||||
|
return self.type == EventType.ERROR
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
"""
|
"""
|
||||||
Create a copy of the event.
|
Create a copy of the event.
|
||||||
@@ -129,11 +143,28 @@ class Subscription:
|
|||||||
|
|
||||||
|
|
||||||
class EventDispatcher:
|
class EventDispatcher:
|
||||||
|
"""Event dispatch engine.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
``start()`` must be called before dispatching or processing events.
|
||||||
|
The internal ``asyncio.Queue`` is created lazily inside ``start()``
|
||||||
|
so that it binds to the correct running event loop (required for
|
||||||
|
Python 3.9/3.10 compatibility).
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.queue: asyncio.Queue[Event] = asyncio.Queue()
|
self.queue: Optional[asyncio.Queue[Event]] = None
|
||||||
self.subscriptions: List[Subscription] = []
|
self.subscriptions: List[Subscription] = []
|
||||||
self.running = False
|
self.running = False
|
||||||
self._task = None
|
self._task = None
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
def _spawn_background(self, coro) -> asyncio.Task:
|
||||||
|
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||||
|
task = asyncio.create_task(coro)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
def subscribe(
|
def subscribe(
|
||||||
self,
|
self,
|
||||||
@@ -166,6 +197,10 @@ class EventDispatcher:
|
|||||||
self.subscriptions.remove(subscription)
|
self.subscriptions.remove(subscription)
|
||||||
|
|
||||||
async def dispatch(self, event: Event):
|
async def dispatch(self, event: Event):
|
||||||
|
if self.queue is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"EventDispatcher.start() must be called before dispatching events"
|
||||||
|
)
|
||||||
await self.queue.put(event)
|
await self.queue.put(event)
|
||||||
|
|
||||||
async def _process_events(self):
|
async def _process_events(self):
|
||||||
@@ -197,7 +232,7 @@ class EventDispatcher:
|
|||||||
# returns - avoids the race where create_task schedules the callback after
|
# returns - avoids the race where create_task schedules the callback after
|
||||||
# the waiter has already timed out with done=set().
|
# the waiter has already timed out with done=set().
|
||||||
if asyncio.iscoroutinefunction(subscription.callback):
|
if asyncio.iscoroutinefunction(subscription.callback):
|
||||||
asyncio.create_task(self._execute_callback(subscription.callback, event.clone()))
|
self._spawn_background(self._execute_callback(subscription.callback, event.clone()))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
subscription.callback(event.clone())
|
subscription.callback(event.clone())
|
||||||
@@ -220,6 +255,8 @@ class EventDispatcher:
|
|||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
if not self.running:
|
if not self.running:
|
||||||
|
if self.queue is None:
|
||||||
|
self.queue = asyncio.Queue()
|
||||||
self.running = True
|
self.running = True
|
||||||
self._task = asyncio.create_task(self._process_events())
|
self._task = asyncio.create_task(self._process_events())
|
||||||
|
|
||||||
@@ -227,7 +264,12 @@ class EventDispatcher:
|
|||||||
if self.running:
|
if self.running:
|
||||||
self.running = False
|
self.running = False
|
||||||
if self._task:
|
if self._task:
|
||||||
await self.queue.join()
|
if self.queue is not None:
|
||||||
|
await self.queue.join()
|
||||||
|
# Wait for any in-flight async callbacks to complete before
|
||||||
|
# tearing down (F07: task_done fires before callbacks finish).
|
||||||
|
if self._background_tasks:
|
||||||
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
try:
|
try:
|
||||||
await self._task
|
await self._task
|
||||||
|
|||||||
@@ -28,10 +28,17 @@ class MeshCore:
|
|||||||
auto_reconnect: bool = False,
|
auto_reconnect: bool = False,
|
||||||
max_reconnect_attempts: int = 3,
|
max_reconnect_attempts: int = 3,
|
||||||
):
|
):
|
||||||
# Wrap connection with ConnectionManager
|
# Wrap connection with ConnectionManager.
|
||||||
|
# The reconnect callback ensures send_appstart() runs after every
|
||||||
|
# transport-level reconnect, which is required by firmware to
|
||||||
|
# initialize the session (F02).
|
||||||
self.dispatcher = EventDispatcher()
|
self.dispatcher = EventDispatcher()
|
||||||
self.connection_manager = ConnectionManager(
|
self.connection_manager = ConnectionManager(
|
||||||
cx, self.dispatcher, auto_reconnect, max_reconnect_attempts
|
cx,
|
||||||
|
self.dispatcher,
|
||||||
|
auto_reconnect,
|
||||||
|
max_reconnect_attempts,
|
||||||
|
reconnect_callback=self._on_reconnect,
|
||||||
)
|
)
|
||||||
self.cx = self.connection_manager # For backward compatibility
|
self.cx = self.connection_manager # For backward compatibility
|
||||||
|
|
||||||
@@ -174,6 +181,15 @@ class MeshCore:
|
|||||||
return None
|
return None
|
||||||
return mc
|
return mc
|
||||||
|
|
||||||
|
async def _on_reconnect(self):
|
||||||
|
"""Callback invoked by ConnectionManager after a successful reconnect.
|
||||||
|
|
||||||
|
Firmware requires CMD_APP_START after every transport-level connection
|
||||||
|
to initialize the session. MeshCore.connect() does this on the initial
|
||||||
|
connection; this callback ensures it also happens on reconnects (F02).
|
||||||
|
"""
|
||||||
|
await self.commands.send_appstart()
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
await self.dispatcher.start()
|
await self.dispatcher.start()
|
||||||
result = await self.connection_manager.connect()
|
result = await self.connection_manager.connect()
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ class CommandType(Enum):
|
|||||||
SET_AUTOADD_CONFIG = 58
|
SET_AUTOADD_CONFIG = 58
|
||||||
GET_AUTOADD_CONFIG = 59
|
GET_AUTOADD_CONFIG = 59
|
||||||
GET_ALLOWED_REPEAT_FREQ = 60
|
GET_ALLOWED_REPEAT_FREQ = 60
|
||||||
|
GET_STATS = 56 # R04: CMD_GET_STATS — used by get_stats_core/radio/packets
|
||||||
SET_PATH_HASH_MODE = 61
|
SET_PATH_HASH_MODE = 61
|
||||||
|
|
||||||
# Packet prefixes for the protocol
|
# Packet prefixes for the protocol
|
||||||
@@ -120,3 +121,6 @@ class PacketType(Enum):
|
|||||||
PATH_DISCOVERY_RESPONSE = 0x8D
|
PATH_DISCOVERY_RESPONSE = 0x8D
|
||||||
CONTROL_DATA = 0x8E
|
CONTROL_DATA = 0x8E
|
||||||
CONTACT_DELETED = 0x8F
|
CONTACT_DELETED = 0x8F
|
||||||
|
CONTACTS_FULL = 0x90 # N02: MyMesh::onContactsFull() — 1-byte push, no payload
|
||||||
|
# Note: 0x90 == ControlType.NODE_DISCOVER_RESP in a different namespace.
|
||||||
|
# Not a literal conflict (PacketType vs ControlType), but a maintenance hazard.
|
||||||
|
|||||||
@@ -552,9 +552,9 @@ class MessageReader:
|
|||||||
perms = dbuf.read(1)[0]
|
perms = dbuf.read(1)[0]
|
||||||
res["permissions"] = perms
|
res["permissions"] = perms
|
||||||
res["is_admin"] = (perms & 1) == 1 # Check if admin bit is set
|
res["is_admin"] = (perms & 1) == 1 # Check if admin bit is set
|
||||||
|
if len(data) > 7:
|
||||||
res["pubkey_prefix"] = dbuf.read(6).hex()
|
res["pubkey_prefix"] = dbuf.read(6).hex()
|
||||||
|
|
||||||
attributes = {"pubkey_prefix": res.get("pubkey_prefix")}
|
attributes = {"pubkey_prefix": res.get("pubkey_prefix")}
|
||||||
|
|
||||||
await self.dispatcher.dispatch(
|
await self.dispatcher.dispatch(
|
||||||
@@ -942,6 +942,36 @@ class MessageReader:
|
|||||||
await self.dispatcher.dispatch(
|
await self.dispatcher.dispatch(
|
||||||
Event(EventType.DISCOVER_RESPONSE, ndr, attributes)
|
Event(EventType.DISCOVER_RESPONSE, ndr, attributes)
|
||||||
)
|
)
|
||||||
|
elif packet_type_value == PacketType.CONTACT_DELETED.value:
|
||||||
|
# N01: PUSH_CODE_CONTACT_DELETED (0x8F) — 1-byte code + 32-byte pubkey
|
||||||
|
# Emitted by MyMesh::onContactOverwrite() (MyMesh.cpp:325-334)
|
||||||
|
if len(data) < 33:
|
||||||
|
logger.debug("CONTACT_DELETED frame too short (%d bytes, need 33)", len(data))
|
||||||
|
return
|
||||||
|
pubkey = data[1:33].hex()
|
||||||
|
await self.dispatcher.dispatch(
|
||||||
|
Event(EventType.CONTACT_DELETED, {"pubkey": pubkey}, {"pubkey": pubkey})
|
||||||
|
)
|
||||||
|
|
||||||
|
elif packet_type_value == PacketType.CONTACTS_FULL.value:
|
||||||
|
# N02: PUSH_CODE_CONTACTS_FULL (0x90) — 1-byte push, no payload
|
||||||
|
# Emitted by MyMesh::onContactsFull() (MyMesh.cpp:336)
|
||||||
|
await self.dispatcher.dispatch(Event(EventType.CONTACTS_FULL, {}))
|
||||||
|
|
||||||
|
elif packet_type_value == PacketType.TUNING_PARAMS.value:
|
||||||
|
# N03: RESP_CODE_TUNING_PARAMS (23) — response to CMD_GET_TUNING_PARAMS (43)
|
||||||
|
# Format: 1-byte code + 4-byte rx_delay (LE) + 4-byte airtime_factor (LE) = 9 bytes
|
||||||
|
# Emitted by MyMesh.cpp:1307-1313
|
||||||
|
if len(data) < 9:
|
||||||
|
logger.debug("TUNING_PARAMS frame too short (%d bytes, need 9)", len(data))
|
||||||
|
await self.dispatcher.dispatch(
|
||||||
|
Event(EventType.ERROR, {"reason": "invalid_frame_length"})
|
||||||
|
)
|
||||||
|
return
|
||||||
|
rx_delay = int.from_bytes(data[1:5], byteorder="little")
|
||||||
|
airtime_factor = int.from_bytes(data[5:9], byteorder="little")
|
||||||
|
res = {"rx_delay": rx_delay, "airtime_factor": airtime_factor}
|
||||||
|
await self.dispatcher.dispatch(Event(EventType.TUNING_PARAMS, res))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Unhandled data received {data}")
|
logger.debug(f"Unhandled data received {data}")
|
||||||
@@ -953,4 +983,4 @@ class MessageReader:
|
|||||||
e,
|
e,
|
||||||
data.hex(),
|
data.hex(),
|
||||||
traceback.format_exc(),
|
traceback.format_exc(),
|
||||||
)
|
)
|
||||||
@@ -20,11 +20,19 @@ class SerialConnection:
|
|||||||
self._disconnect_callback = None
|
self._disconnect_callback = None
|
||||||
self.cx_dly = cx_dly
|
self.cx_dly = cx_dly
|
||||||
self._connected_event = asyncio.Event()
|
self._connected_event = asyncio.Event()
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
self.frame_expected_size = 0
|
self.frame_expected_size = 0
|
||||||
self.inframe = b""
|
self.inframe = b""
|
||||||
self.header = b""
|
self.header = b""
|
||||||
|
|
||||||
|
def _spawn_background(self, coro) -> asyncio.Task:
|
||||||
|
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||||
|
task = asyncio.create_task(coro)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
class MCSerialClientProtocol(asyncio.Protocol):
|
class MCSerialClientProtocol(asyncio.Protocol):
|
||||||
def __init__(self, cx):
|
def __init__(self, cx):
|
||||||
self.cx = cx
|
self.cx = cx
|
||||||
@@ -44,7 +52,7 @@ class SerialConnection:
|
|||||||
self.cx._connected_event.clear()
|
self.cx._connected_event.clear()
|
||||||
|
|
||||||
if self.cx._disconnect_callback:
|
if self.cx._disconnect_callback:
|
||||||
asyncio.create_task(self.cx._disconnect_callback("serial_disconnect"))
|
self.cx._spawn_background(self.cx._disconnect_callback("serial_disconnect"))
|
||||||
|
|
||||||
def pause_writing(self):
|
def pause_writing(self):
|
||||||
logger.debug("pause writing")
|
logger.debug("pause writing")
|
||||||
@@ -52,12 +60,16 @@ class SerialConnection:
|
|||||||
def resume_writing(self):
|
def resume_writing(self):
|
||||||
logger.debug("resume writing")
|
logger.debug("resume writing")
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self, timeout: float = 10.0):
|
||||||
"""
|
"""
|
||||||
Connects to the device
|
Connects to the device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum seconds to wait for connection_made callback.
|
||||||
|
Defaults to 10.0. Raises asyncio.TimeoutError on expiry.
|
||||||
"""
|
"""
|
||||||
self._connected_event.clear()
|
self._connected_event.clear()
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await serial_asyncio.create_serial_connection(
|
await serial_asyncio.create_serial_connection(
|
||||||
loop,
|
loop,
|
||||||
@@ -66,7 +78,7 @@ class SerialConnection:
|
|||||||
baudrate=self.baudrate,
|
baudrate=self.baudrate,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._connected_event.wait()
|
await asyncio.wait_for(self._connected_event.wait(), timeout=timeout)
|
||||||
logger.info("Serial Connection started")
|
logger.info("Serial Connection started")
|
||||||
return self.port
|
return self.port
|
||||||
|
|
||||||
@@ -102,7 +114,7 @@ class SerialConnection:
|
|||||||
self.frame_expected_size = 0
|
self.frame_expected_size = 0
|
||||||
if len(data) > 0: # rerun handle_rx on remaining data
|
if len(data) > 0: # rerun handle_rx on remaining data
|
||||||
self.handle_rx(data)
|
self.handle_rx(data)
|
||||||
return
|
return # nothing left to process after reset
|
||||||
|
|
||||||
upbound = self.frame_expected_size - len(self.inframe)
|
upbound = self.frame_expected_size - len(self.inframe)
|
||||||
if len(data) < upbound:
|
if len(data) < upbound:
|
||||||
@@ -114,7 +126,7 @@ class SerialConnection:
|
|||||||
data = data[upbound:]
|
data = data[upbound:]
|
||||||
if self.reader is not None:
|
if self.reader is not None:
|
||||||
# feed meshcore reader
|
# feed meshcore reader
|
||||||
asyncio.create_task(self.reader.handle_rx(self.inframe))
|
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||||
# reset inframe
|
# reset inframe
|
||||||
self.inframe = b""
|
self.inframe = b""
|
||||||
self.header = b""
|
self.header = b""
|
||||||
@@ -125,11 +137,18 @@ class SerialConnection:
|
|||||||
async def send(self, data):
|
async def send(self, data):
|
||||||
if not self.transport:
|
if not self.transport:
|
||||||
logger.error("Transport not connected, cannot send data")
|
logger.error("Transport not connected, cannot send data")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback("serial_transport_lost")
|
||||||
return
|
return
|
||||||
size = len(data)
|
size = len(data)
|
||||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||||
logger.debug(f"sending pkt : {pkt}")
|
logger.debug(f"sending pkt : {pkt}")
|
||||||
self.transport.write(pkt)
|
try:
|
||||||
|
self.transport.write(pkt)
|
||||||
|
except OSError as exc:
|
||||||
|
logger.warning(f"Serial write failed: {exc}")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback(f"serial_write_failed: {exc}")
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Close the serial connection."""
|
"""Close the serial connection."""
|
||||||
|
|||||||
@@ -24,6 +24,14 @@ class TCPConnection:
|
|||||||
self.frame_expected_size = 0
|
self.frame_expected_size = 0
|
||||||
self.header = b""
|
self.header = b""
|
||||||
self.inframe = b""
|
self.inframe = b""
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
def _spawn_background(self, coro) -> asyncio.Task:
|
||||||
|
"""Create a tracked background task (prevents GC of fire-and-forget tasks)."""
|
||||||
|
task = asyncio.create_task(coro)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
return task
|
||||||
|
|
||||||
class MCClientProtocol(asyncio.Protocol):
|
class MCClientProtocol(asyncio.Protocol):
|
||||||
def __init__(self, cx):
|
def __init__(self, cx):
|
||||||
@@ -38,7 +46,6 @@ class TCPConnection:
|
|||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
logger.debug("data received")
|
logger.debug("data received")
|
||||||
self.cx._receive_count += 1
|
|
||||||
self.cx.handle_rx(data)
|
self.cx.handle_rx(data)
|
||||||
|
|
||||||
def error_received(self, exc):
|
def error_received(self, exc):
|
||||||
@@ -47,7 +54,7 @@ class TCPConnection:
|
|||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
logger.debug("TCP server closed the connection")
|
logger.debug("TCP server closed the connection")
|
||||||
if self.cx._disconnect_callback:
|
if self.cx._disconnect_callback:
|
||||||
asyncio.create_task(self.cx._disconnect_callback("tcp_disconnect"))
|
self.cx._spawn_background(self.cx._disconnect_callback("tcp_disconnect"))
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
"""
|
"""
|
||||||
@@ -59,10 +66,7 @@ class TCPConnection:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info("TCP Connection started")
|
logger.info("TCP Connection started")
|
||||||
future = asyncio.Future()
|
return self.host
|
||||||
future.set_result(self.host)
|
|
||||||
|
|
||||||
return future
|
|
||||||
|
|
||||||
def set_reader(self, reader):
|
def set_reader(self, reader):
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
@@ -96,7 +100,7 @@ class TCPConnection:
|
|||||||
self.frame_expected_size = 0
|
self.frame_expected_size = 0
|
||||||
if len(data) > 0: # rerun handle_rx on remaining data
|
if len(data) > 0: # rerun handle_rx on remaining data
|
||||||
self.handle_rx(data)
|
self.handle_rx(data)
|
||||||
return
|
return # nothing left to process after reset
|
||||||
|
|
||||||
upbound = self.frame_expected_size - len(self.inframe)
|
upbound = self.frame_expected_size - len(self.inframe)
|
||||||
if len(data) < upbound :
|
if len(data) < upbound :
|
||||||
@@ -106,9 +110,13 @@ class TCPConnection:
|
|||||||
|
|
||||||
self.inframe = self.inframe + data[0:upbound]
|
self.inframe = self.inframe + data[0:upbound]
|
||||||
data = data[upbound:]
|
data = data[upbound:]
|
||||||
|
# Increment per completed MeshCore frame, not per TCP segment (N04).
|
||||||
|
# The threshold heuristic in send() compares _send_count vs
|
||||||
|
# _receive_count — counting per-segment skews it under fragmentation.
|
||||||
|
self._receive_count += 1
|
||||||
if self.reader is not None:
|
if self.reader is not None:
|
||||||
# feed meshcore reader
|
# feed meshcore reader
|
||||||
asyncio.create_task(self.reader.handle_rx(self.inframe))
|
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||||
# reset inframe
|
# reset inframe
|
||||||
self.inframe = b""
|
self.inframe = b""
|
||||||
self.header = b""
|
self.header = b""
|
||||||
@@ -137,7 +145,12 @@ class TCPConnection:
|
|||||||
size = len(data)
|
size = len(data)
|
||||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||||
logger.debug(f"sending pkt : {pkt}")
|
logger.debug(f"sending pkt : {pkt}")
|
||||||
self.transport.write(pkt)
|
try:
|
||||||
|
self.transport.write(pkt)
|
||||||
|
except (OSError, ConnectionResetError) as exc:
|
||||||
|
logger.warning(f"TCP write failed: {exc}")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback(f"tcp_write_failed: {exc}")
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Close the TCP connection."""
|
"""Close the TCP connection."""
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase):
|
|||||||
|
|
||||||
@patch("meshcore.ble_cx.BleakClient")
|
@patch("meshcore.ble_cx.BleakClient")
|
||||||
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
|
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
|
||||||
"""Test BLE connection with PIN when pairing fails but connection continues"""
|
"""Test BLE connection with PIN when pairing fails — re-raises (F17)."""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_client_instance = self._get_mock_bleak_client()
|
mock_client_instance = self._get_mock_bleak_client()
|
||||||
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
|
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
|
||||||
@@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase):
|
|||||||
pin = "123456"
|
pin = "123456"
|
||||||
ble_conn = BLEConnection(address=address, pin=pin)
|
ble_conn = BLEConnection(address=address, pin=pin)
|
||||||
|
|
||||||
# Act
|
# Act & Assert — pairing failure now re-raises instead of being
|
||||||
result = asyncio.run(ble_conn.connect())
|
# swallowed, because a half-usable transport is worse than a clean
|
||||||
|
# failure (forensics finding F17).
|
||||||
# Assert
|
with self.assertRaises(Exception) as ctx:
|
||||||
|
asyncio.run(ble_conn.connect())
|
||||||
|
self.assertIn("Pairing failed", str(ctx.exception))
|
||||||
mock_client_instance.connect.assert_called_once()
|
mock_client_instance.connect.assert_called_once()
|
||||||
mock_client_instance.pair.assert_called_once()
|
mock_client_instance.pair.assert_called_once()
|
||||||
mock_client_instance.start_notify.assert_called_once_with(
|
# disconnect should be called to clean up the failed connection
|
||||||
UART_TX_CHAR_UUID, ble_conn.handle_rx
|
mock_client_instance.disconnect.assert_called_once()
|
||||||
)
|
|
||||||
# Connection should still succeed even if pairing fails
|
|
||||||
self.assertEqual(result, address)
|
|
||||||
|
|
||||||
@patch("meshcore.ble_cx.BleakClient")
|
@patch("meshcore.ble_cx.BleakClient")
|
||||||
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):
|
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):
|
||||||
|
|||||||
235
tests/unit/test_asyncio_lifecycle.py
Normal file
235
tests/unit/test_asyncio_lifecycle.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""
|
||||||
|
Verification tests for asyncio lifecycle fixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import gc
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from meshcore.events import Event, EventDispatcher, EventType
|
||||||
|
from meshcore.tcp_cx import TCPConnection
|
||||||
|
from meshcore.serial_cx import SerialConnection
|
||||||
|
from meshcore.commands.base import CommandHandlerBase
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackgroundTaskTracking(unittest.TestCase):
|
||||||
|
"""Fire-and-forget create_task calls must be tracked to prevent GC."""
|
||||||
|
|
||||||
|
def test_tcp_spawn_background_retains_task(self):
|
||||||
|
"""TCP _spawn_background adds the task to _background_tasks."""
|
||||||
|
async def _run():
|
||||||
|
cx = TCPConnection("127.0.0.1", 5555)
|
||||||
|
completed = asyncio.Event()
|
||||||
|
|
||||||
|
async def dummy():
|
||||||
|
completed.set()
|
||||||
|
|
||||||
|
task = cx._spawn_background(dummy())
|
||||||
|
assert task in cx._background_tasks
|
||||||
|
await completed.wait()
|
||||||
|
# After completion, done_callback should have discarded it
|
||||||
|
await asyncio.sleep(0) # let done callback fire
|
||||||
|
assert task not in cx._background_tasks
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_serial_spawn_background_retains_task(self):
|
||||||
|
"""Serial _spawn_background adds the task to _background_tasks."""
|
||||||
|
async def _run():
|
||||||
|
with patch("meshcore.serial_cx.asyncio.Event") as mock_event:
|
||||||
|
mock_event.return_value = MagicMock()
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
completed = asyncio.Event()
|
||||||
|
|
||||||
|
async def dummy():
|
||||||
|
completed.set()
|
||||||
|
|
||||||
|
task = cx._spawn_background(dummy())
|
||||||
|
assert task in cx._background_tasks
|
||||||
|
await completed.wait()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert task not in cx._background_tasks
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_event_dispatcher_spawn_background_retains_task(self):
|
||||||
|
"""EventDispatcher _spawn_background adds task to _background_tasks."""
|
||||||
|
async def _run():
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
completed = asyncio.Event()
|
||||||
|
|
||||||
|
async def dummy():
|
||||||
|
completed.set()
|
||||||
|
|
||||||
|
task = dispatcher._spawn_background(dummy())
|
||||||
|
assert task in dispatcher._background_tasks
|
||||||
|
await completed.wait()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert task not in dispatcher._background_tasks
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_tcp_handle_rx_uses_tracked_task(self):
|
||||||
|
"""TCP handle_rx dispatches reader.handle_rx via _spawn_background."""
|
||||||
|
async def _run():
|
||||||
|
cx = TCPConnection("127.0.0.1", 5555)
|
||||||
|
reader = AsyncMock()
|
||||||
|
reader.handle_rx = AsyncMock()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
|
||||||
|
# Build a minimal valid frame: 0x3e + 2-byte LE size + payload
|
||||||
|
payload = b"\x01\x02\x03"
|
||||||
|
size = len(payload).to_bytes(2, "little")
|
||||||
|
frame = b"\x3e" + size + payload
|
||||||
|
|
||||||
|
cx.handle_rx(frame)
|
||||||
|
# Task should be tracked
|
||||||
|
assert len(cx._background_tasks) == 1
|
||||||
|
# Let task complete
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
reader.handle_rx.assert_awaited_once_with(payload)
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_tcp_connection_lost_uses_tracked_task(self):
|
||||||
|
"""TCP connection_lost dispatches disconnect callback via _spawn_background."""
|
||||||
|
async def _run():
|
||||||
|
cx = TCPConnection("127.0.0.1", 5555)
|
||||||
|
callback = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(callback)
|
||||||
|
|
||||||
|
protocol = cx.MCClientProtocol(cx)
|
||||||
|
protocol.connection_lost(None)
|
||||||
|
|
||||||
|
assert len(cx._background_tasks) == 1
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
callback.assert_awaited_once_with("tcp_disconnect")
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_gc_does_not_cancel_tracked_tasks(self):
|
||||||
|
"""Tracked tasks survive GC pressure (the whole point of tracking)."""
|
||||||
|
async def _run():
|
||||||
|
cx = TCPConnection("127.0.0.1", 5555)
|
||||||
|
result = []
|
||||||
|
|
||||||
|
async def slow_task():
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
result.append("done")
|
||||||
|
|
||||||
|
cx._spawn_background(slow_task())
|
||||||
|
# Force GC — untracked tasks could be collected here
|
||||||
|
gc.collect()
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
assert result == ["done"]
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskDoneCorrectness(unittest.TestCase):
|
||||||
|
"""EventDispatcher.stop() must wait for in-flight async callbacks."""
|
||||||
|
|
||||||
|
def test_stop_waits_for_async_callbacks(self):
|
||||||
|
"""stop() should not return until async callbacks have completed."""
|
||||||
|
async def _run():
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
|
||||||
|
callback_completed = False
|
||||||
|
|
||||||
|
async def slow_callback(event):
|
||||||
|
nonlocal callback_completed
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
callback_completed = True
|
||||||
|
|
||||||
|
dispatcher.subscribe(EventType.OK, slow_callback)
|
||||||
|
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||||
|
|
||||||
|
# Give the dispatch loop a moment to pick up the event
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
|
||||||
|
# stop() should wait for slow_callback to finish
|
||||||
|
await dispatcher.stop()
|
||||||
|
assert callback_completed, "stop() returned before async callback completed"
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeferredPrimitiveConstruction(unittest.TestCase):
|
||||||
|
"""Queue and Lock must not bind to import-time loop."""
|
||||||
|
|
||||||
|
def test_event_dispatcher_queue_is_none_before_start(self):
|
||||||
|
"""EventDispatcher.queue should be None until start() is called."""
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
assert dispatcher.queue is None
|
||||||
|
|
||||||
|
def test_event_dispatcher_queue_created_on_start(self):
|
||||||
|
"""start() creates the queue."""
|
||||||
|
async def _run():
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
assert dispatcher.queue is None
|
||||||
|
await dispatcher.start()
|
||||||
|
assert dispatcher.queue is not None
|
||||||
|
assert isinstance(dispatcher.queue, asyncio.Queue)
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_event_dispatcher_dispatch_before_start_raises(self):
|
||||||
|
"""dispatch() before start() should raise RuntimeError."""
|
||||||
|
async def _run():
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
def test_command_handler_lock_is_none_before_use(self):
|
||||||
|
"""CommandHandlerBase lock should be None until first access."""
|
||||||
|
handler = CommandHandlerBase()
|
||||||
|
assert handler._CommandHandlerBase__mesh_request_lock is None
|
||||||
|
|
||||||
|
def test_command_handler_lock_created_on_access(self):
|
||||||
|
"""Accessing _mesh_request_lock creates it lazily."""
|
||||||
|
async def _run():
|
||||||
|
handler = CommandHandlerBase()
|
||||||
|
lock = handler._mesh_request_lock
|
||||||
|
assert isinstance(lock, asyncio.Lock)
|
||||||
|
# Second access returns same instance
|
||||||
|
assert handler._mesh_request_lock is lock
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetRunningLoop(unittest.TestCase):
|
||||||
|
"""get_event_loop() replaced with get_running_loop() in send()."""
|
||||||
|
|
||||||
|
def test_send_uses_get_running_loop(self):
|
||||||
|
"""send() should call get_running_loop, not get_event_loop."""
|
||||||
|
async def _run():
|
||||||
|
handler = CommandHandlerBase()
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
handler.set_dispatcher(dispatcher)
|
||||||
|
|
||||||
|
mock_sender = AsyncMock()
|
||||||
|
handler._sender_func = mock_sender
|
||||||
|
|
||||||
|
# Patch get_running_loop to verify it's called
|
||||||
|
with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl:
|
||||||
|
# send with expected_events triggers the loop = asyncio.get_running_loop() path
|
||||||
|
result = await handler.send(
|
||||||
|
b"\x01",
|
||||||
|
expected_events=[EventType.OK],
|
||||||
|
timeout=0.05,
|
||||||
|
)
|
||||||
|
mock_grl.assert_called()
|
||||||
|
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -28,6 +28,10 @@ def mock_dispatcher():
|
|||||||
sub.unsubscribe = MagicMock()
|
sub.unsubscribe = MagicMock()
|
||||||
dispatcher._last_subscribe_handler = handler
|
dispatcher._last_subscribe_handler = handler
|
||||||
dispatcher._last_subscribe_event_type = event_type
|
dispatcher._last_subscribe_event_type = event_type
|
||||||
|
# Immediately resolve the future so send() doesn't block
|
||||||
|
asyncio.get_event_loop().call_soon(
|
||||||
|
handler, Event(event_type, {})
|
||||||
|
)
|
||||||
return sub
|
return sub
|
||||||
|
|
||||||
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||||
@@ -80,6 +84,13 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher
|
|||||||
|
|
||||||
|
|
||||||
async def test_send_timeout(command_handler, mock_connection, mock_dispatcher):
|
async def test_send_timeout(command_handler, mock_connection, mock_dispatcher):
|
||||||
|
# Override to NOT resolve events, so we can test the timeout path
|
||||||
|
def non_resolving_subscribe(event_type, handler, attribute_filters=None):
|
||||||
|
sub = MagicMock(spec=Subscription)
|
||||||
|
sub.unsubscribe = MagicMock()
|
||||||
|
return sub
|
||||||
|
mock_dispatcher.subscribe = MagicMock(side_effect=non_resolving_subscribe)
|
||||||
|
|
||||||
result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1)
|
result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1)
|
||||||
assert result.type == EventType.ERROR
|
assert result.type == EventType.ERROR
|
||||||
assert result.payload == {"reason": "no_event_received"}
|
assert result.payload == {"reason": "no_event_received"}
|
||||||
|
|||||||
294
tests/unit/test_connection_manager.py
Normal file
294
tests/unit/test_connection_manager.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""Tests for reconnect-path fixes."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from meshcore.connection_manager import ConnectionManager
|
||||||
|
from meshcore.events import Event, EventDispatcher, EventType
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class FakeConnection:
|
||||||
|
"""Minimal stub that satisfies ConnectionProtocol."""
|
||||||
|
|
||||||
|
def __init__(self, connect_results=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
connect_results: iterator of return values for successive
|
||||||
|
connect() calls. ``None`` means soft failure; a string
|
||||||
|
means success; raising is also supported via sentinel.
|
||||||
|
"""
|
||||||
|
self._connect_results = list(connect_results or ["ok"])
|
||||||
|
self._call_index = 0
|
||||||
|
self.reader = None
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
if self._call_index < len(self._connect_results):
|
||||||
|
result = self._connect_results[self._call_index]
|
||||||
|
self._call_index += 1
|
||||||
|
else:
|
||||||
|
result = self._connect_results[-1]
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send(self, data):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_reader(self, reader):
|
||||||
|
self.reader = reader
|
||||||
|
|
||||||
|
|
||||||
|
class RaisingConnection(FakeConnection):
|
||||||
|
"""Connection that raises on every connect() attempt."""
|
||||||
|
|
||||||
|
def __init__(self, exc=None):
|
||||||
|
super().__init__()
|
||||||
|
self._exc = exc or ConnectionError("boom")
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
raise self._exc
|
||||||
|
|
||||||
|
|
||||||
|
class _EventCollector:
|
||||||
|
"""Subscribes to all events and records them."""
|
||||||
|
|
||||||
|
def __init__(self, dispatcher: EventDispatcher):
|
||||||
|
self.events: list[Event] = []
|
||||||
|
dispatcher.subscribe(None, self._on_event)
|
||||||
|
|
||||||
|
async def _on_event(self, event: Event):
|
||||||
|
self.events.append(event)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TCP connect() should return a plain value, not an asyncio.Future
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_connect_returns_plain_string():
|
||||||
|
"""TCPConnection.connect() returns self.host (a plain string), not an
|
||||||
|
asyncio.Future. We test indirectly via ConnectionManager — the
|
||||||
|
CONNECTED event payload should contain a plain string, not a Future
|
||||||
|
object."""
|
||||||
|
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
collector = _EventCollector(dispatcher)
|
||||||
|
mgr = ConnectionManager(conn, dispatcher)
|
||||||
|
|
||||||
|
result = await mgr.connect()
|
||||||
|
|
||||||
|
assert result == "10.0.0.1"
|
||||||
|
# Give the dispatcher a moment to deliver the event
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
connected_events = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||||
|
assert len(connected_events) == 1
|
||||||
|
payload = connected_events[0].payload
|
||||||
|
assert payload["connection_info"] == "10.0.0.1"
|
||||||
|
# The payload value must NOT be an asyncio.Future
|
||||||
|
assert not isinstance(payload["connection_info"], asyncio.Future)
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reconnect attempts must not compound (no tail-recursive create_task)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconnect_loop_does_not_compound():
|
||||||
|
"""_attempt_reconnect must use a single iterative loop. After
|
||||||
|
max_reconnect_attempts failures, exactly that many connect() calls
|
||||||
|
should have been made — no exponential fan-out from orphaned tasks."""
|
||||||
|
# All attempts fail (return None)
|
||||||
|
conn = FakeConnection(connect_results=[None, None, None, None])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
collector = _EventCollector(dispatcher)
|
||||||
|
mgr = ConnectionManager(
|
||||||
|
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3,
|
||||||
|
)
|
||||||
|
mgr._is_connected = True # simulate a live connection
|
||||||
|
|
||||||
|
await mgr.handle_disconnect("test_disconnect")
|
||||||
|
# Wait for the reconnect loop to exhaust all attempts
|
||||||
|
# (3 attempts × 1s sleep each, but we can just await the task)
|
||||||
|
if mgr._reconnect_task:
|
||||||
|
await mgr._reconnect_task
|
||||||
|
|
||||||
|
# Exactly 3 connect() calls should have been made
|
||||||
|
assert conn._call_index == 3
|
||||||
|
|
||||||
|
# A DISCONNECTED event with max_attempts_exceeded should have fired
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED]
|
||||||
|
assert len(disconnected) == 1
|
||||||
|
assert disconnected[0].payload.get("max_attempts_exceeded") is True
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disconnect_cancels_reconnect_loop():
|
||||||
|
"""disconnect() during an active reconnect loop must cancel the
|
||||||
|
single task cleanly — no orphaned tasks left running."""
|
||||||
|
# Simulate a connection that always fails (returns None), giving us
|
||||||
|
# time to call disconnect() mid-loop.
|
||||||
|
conn = FakeConnection(connect_results=[None, None, None, None, None])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
mgr = ConnectionManager(
|
||||||
|
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5,
|
||||||
|
)
|
||||||
|
mgr._is_connected = True
|
||||||
|
|
||||||
|
await mgr.handle_disconnect("test_disconnect")
|
||||||
|
|
||||||
|
# Let the first attempt start (wait just past the 1s sleep)
|
||||||
|
await asyncio.sleep(1.2)
|
||||||
|
assert conn._call_index >= 1 # at least one attempt made
|
||||||
|
|
||||||
|
# Now disconnect — should cancel the loop
|
||||||
|
await mgr.disconnect()
|
||||||
|
|
||||||
|
assert mgr._reconnect_task is None
|
||||||
|
calls_at_cancel = conn._call_index
|
||||||
|
|
||||||
|
# Wait a bit and confirm no more attempts happened
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
assert conn._call_index == calls_at_cancel
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# reconnect_callback (send_appstart) is called after reconnect
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconnect_callback_called_after_reconnect():
|
||||||
|
"""When ConnectionManager reconnects successfully, the
|
||||||
|
reconnect_callback (e.g. send_appstart) must be invoked."""
|
||||||
|
callback_called = []
|
||||||
|
|
||||||
|
async def fake_appstart():
|
||||||
|
callback_called.append(True)
|
||||||
|
|
||||||
|
# First connect() fails (None), second succeeds
|
||||||
|
conn = FakeConnection(connect_results=[None, "10.0.0.1"])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
mgr = ConnectionManager(
|
||||||
|
conn, dispatcher,
|
||||||
|
auto_reconnect=True,
|
||||||
|
max_reconnect_attempts=3,
|
||||||
|
reconnect_callback=fake_appstart,
|
||||||
|
)
|
||||||
|
mgr._is_connected = True
|
||||||
|
|
||||||
|
await mgr.handle_disconnect("test_disconnect")
|
||||||
|
if mgr._reconnect_task:
|
||||||
|
await mgr._reconnect_task
|
||||||
|
|
||||||
|
assert len(callback_called) == 1
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconnect_callback_failure_does_not_crash_loop():
|
||||||
|
"""If the reconnect_callback raises, the reconnect still counts as
|
||||||
|
successful (transport is up) — the callback failure is logged but
|
||||||
|
does not crash the loop or leave the manager in a broken state."""
|
||||||
|
async def failing_callback():
|
||||||
|
raise RuntimeError("appstart failed")
|
||||||
|
|
||||||
|
# connect() succeeds on first attempt
|
||||||
|
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
collector = _EventCollector(dispatcher)
|
||||||
|
mgr = ConnectionManager(
|
||||||
|
conn, dispatcher,
|
||||||
|
auto_reconnect=True,
|
||||||
|
max_reconnect_attempts=3,
|
||||||
|
reconnect_callback=failing_callback,
|
||||||
|
)
|
||||||
|
mgr._is_connected = True
|
||||||
|
|
||||||
|
await mgr.handle_disconnect("test_disconnect")
|
||||||
|
if mgr._reconnect_task:
|
||||||
|
await mgr._reconnect_task
|
||||||
|
|
||||||
|
# Despite callback failure, CONNECTED event should have fired
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||||
|
assert len(connected) == 1
|
||||||
|
assert mgr._is_connected is True
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# connect() returning None is a soft failure (BLE scan miss)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_none_is_soft_failure():
|
||||||
|
"""When connect() returns None (e.g. BLE scan found no device),
|
||||||
|
ConnectionManager.connect() should NOT set _is_connected and should
|
||||||
|
NOT emit a CONNECTED event."""
|
||||||
|
conn = FakeConnection(connect_results=[None])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
collector = _EventCollector(dispatcher)
|
||||||
|
mgr = ConnectionManager(conn, dispatcher)
|
||||||
|
|
||||||
|
result = await mgr.connect()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert mgr._is_connected is False
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||||
|
assert len(connected) == 0
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_reconnect_callback_is_noop():
|
||||||
|
"""When no reconnect_callback is provided (backwards compat for
|
||||||
|
direct ConnectionManager users), reconnect should still work."""
|
||||||
|
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||||
|
dispatcher = EventDispatcher()
|
||||||
|
await dispatcher.start()
|
||||||
|
try:
|
||||||
|
mgr = ConnectionManager(
|
||||||
|
conn, dispatcher,
|
||||||
|
auto_reconnect=True,
|
||||||
|
max_reconnect_attempts=3,
|
||||||
|
# No reconnect_callback — default None
|
||||||
|
)
|
||||||
|
mgr._is_connected = True
|
||||||
|
|
||||||
|
await mgr.handle_disconnect("test_disconnect")
|
||||||
|
if mgr._reconnect_task:
|
||||||
|
await mgr._reconnect_task
|
||||||
|
|
||||||
|
assert mgr._is_connected is True
|
||||||
|
finally:
|
||||||
|
await dispatcher.stop()
|
||||||
236
tests/unit/test_error_handling.py
Normal file
236
tests/unit/test_error_handling.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""Verification tests for error response handling fixes.
|
||||||
|
|
||||||
|
The tests confirm that error responses are surfaced cleanly instead
|
||||||
|
of causing KeyError, TypeError, NameError, or silent fallthrough.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
from meshcore.commands import CommandHandler
|
||||||
|
from meshcore.events import EventType, Event, Subscription
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_connection():
|
||||||
|
connection = MagicMock()
|
||||||
|
connection.send = AsyncMock()
|
||||||
|
return connection
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dispatcher():
|
||||||
|
dispatcher = MagicMock()
|
||||||
|
dispatcher.wait_for_event = AsyncMock()
|
||||||
|
dispatcher.dispatch = AsyncMock()
|
||||||
|
|
||||||
|
def fake_subscribe(event_type, handler, attribute_filters=None):
|
||||||
|
sub = MagicMock(spec=Subscription)
|
||||||
|
sub.unsubscribe = MagicMock()
|
||||||
|
dispatcher._last_subscribe_handler = handler
|
||||||
|
dispatcher._last_subscribe_event_type = event_type
|
||||||
|
return sub
|
||||||
|
|
||||||
|
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||||
|
return dispatcher
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def command_handler(mock_connection, mock_dispatcher):
|
||||||
|
handler = CommandHandler()
|
||||||
|
|
||||||
|
async def sender(data):
|
||||||
|
await mock_connection.send(data)
|
||||||
|
|
||||||
|
handler._sender_func = sender
|
||||||
|
handler.dispatcher = mock_dispatcher
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def setup_error_response(mock_dispatcher):
|
||||||
|
"""Configure dispatcher to return an ERROR event for any subscribe."""
|
||||||
|
def fake_subscribe(evt_type, handler, attr_filters=None):
|
||||||
|
sub = MagicMock(spec=Subscription)
|
||||||
|
sub.unsubscribe = MagicMock()
|
||||||
|
# Always fire ERROR regardless of which event type was subscribed
|
||||||
|
if evt_type == EventType.ERROR:
|
||||||
|
asyncio.get_event_loop().call_soon(
|
||||||
|
handler, Event(EventType.ERROR, {"reason": "test_error"})
|
||||||
|
)
|
||||||
|
return sub
|
||||||
|
|
||||||
|
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_event_response(mock_dispatcher, event_type, payload):
|
||||||
|
"""Configure dispatcher to return a specific event."""
|
||||||
|
def fake_subscribe(evt_type, handler, attr_filters=None):
|
||||||
|
sub = MagicMock(spec=Subscription)
|
||||||
|
sub.unsubscribe = MagicMock()
|
||||||
|
if evt_type == event_type:
|
||||||
|
asyncio.get_event_loop().call_soon(
|
||||||
|
handler, Event(event_type, payload)
|
||||||
|
)
|
||||||
|
return sub
|
||||||
|
|
||||||
|
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Event.is_error() helper ──────────────────────────────────
|
||||||
|
|
||||||
|
async def test_event_is_error_true():
|
||||||
|
"""is_error() returns True for ERROR events."""
|
||||||
|
event = Event(EventType.ERROR, {"reason": "test"})
|
||||||
|
assert event.is_error() is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_event_is_error_false():
|
||||||
|
"""is_error() returns False for non-ERROR events."""
|
||||||
|
event = Event(EventType.OK, {})
|
||||||
|
assert event.is_error() is False
|
||||||
|
event2 = Event(EventType.SELF_INFO, {"name": "test"})
|
||||||
|
assert event2.is_error() is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── send_msg_with_retry continues on ERROR ──────────────
|
||||||
|
|
||||||
|
async def test_send_msg_with_retry_error_no_keyerror(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""send_msg_with_retry returns None (exhausted retries) on
|
||||||
|
persistent ERROR instead of raising KeyError on missing 'expected_ack'."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
# Provide a mock contact so the path logic doesn't interfere
|
||||||
|
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
# max_attempts=2 so it retries once then gives up
|
||||||
|
result = await command_handler.send_msg_with_retry(
|
||||||
|
VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return None (no ACK received) rather than raising KeyError
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── send_appstart includes ERROR in expected events ──────────
|
||||||
|
|
||||||
|
async def test_send_appstart_returns_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""send_appstart returns ERROR event instead of hanging on timeout."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.send_appstart()
|
||||||
|
|
||||||
|
assert result.type == EventType.ERROR
|
||||||
|
assert result.is_error() is True
|
||||||
|
assert result.payload["reason"] == "test_error"
|
||||||
|
|
||||||
|
|
||||||
|
# ── device setters return ERROR from send_appstart ───────────
|
||||||
|
|
||||||
|
async def test_set_telemetry_mode_base_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""set_telemetry_mode_base returns ERROR instead of KeyError."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.set_telemetry_mode_base(1)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
assert result.payload["reason"] == "test_error"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_telemetry_mode_loc_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""set_telemetry_mode_loc returns ERROR instead of KeyError."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.set_telemetry_mode_loc(1)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_telemetry_mode_env_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""set_telemetry_mode_env returns ERROR instead of KeyError."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.set_telemetry_mode_env(1)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_manual_add_contacts_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""set_manual_add_contacts returns ERROR instead of KeyError."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.set_manual_add_contacts(True)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_advert_loc_policy_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""set_advert_loc_policy returns ERROR instead of KeyError."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.set_advert_loc_policy(1)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_set_multi_acks_error(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""set_multi_acks returns ERROR instead of KeyError."""
|
||||||
|
setup_error_response(mock_dispatcher)
|
||||||
|
|
||||||
|
result = await command_handler.set_multi_acks(1)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
|
||||||
|
|
||||||
|
# ── send_anon_req returns ERROR on contact not found ─────────
|
||||||
|
|
||||||
|
async def test_send_anon_req_contact_not_found(
|
||||||
|
command_handler, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""send_anon_req returns ERROR event when contact prefix not found,
|
||||||
|
instead of raising TypeError on NoneType subscript."""
|
||||||
|
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
result = await command_handler.send_anon_req(
|
||||||
|
VALID_PUBKEY_HEX, MagicMock(value=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
assert result.payload["reason"] == "contact_not_found"
|
||||||
|
|
||||||
|
|
||||||
|
# ── send_trace handles unknown path_hash_len without NameError ──
|
||||||
|
|
||||||
|
async def test_send_trace_unknown_path_hash_len(
|
||||||
|
command_handler, mock_connection, mock_dispatcher
|
||||||
|
):
|
||||||
|
"""send_trace with a path whose segments don't match any known
|
||||||
|
path_hash_len returns ERROR cleanly instead of NameError on 'e'."""
|
||||||
|
# 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8
|
||||||
|
result = await command_handler.send_trace(
|
||||||
|
auth_code=0, tag=1, flags=None, path="abcde"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.is_error()
|
||||||
|
assert result.payload["reason"] == "invalid_path_format"
|
||||||
364
tests/unit/test_protocol_surface_gaps.py
Normal file
364
tests/unit/test_protocol_surface_gaps.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""Verification tests for protocol surface gaps.
|
||||||
|
|
||||||
|
Each test constructs a mock firmware frame and verifies the SDK dispatches
|
||||||
|
the correct EventType with the expected payload fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import struct
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from meshcore.events import Event, EventType, EventDispatcher
|
||||||
|
from meshcore.reader import MessageReader
|
||||||
|
from meshcore.packets import PacketType, CommandType
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_reader():
|
||||||
|
"""Create a MessageReader with a mock dispatcher that records dispatched events."""
|
||||||
|
dispatcher = MagicMock(spec=EventDispatcher)
|
||||||
|
dispatched = []
|
||||||
|
|
||||||
|
async def _capture(event):
|
||||||
|
dispatched.append(event)
|
||||||
|
|
||||||
|
dispatcher.dispatch = AsyncMock(side_effect=_capture)
|
||||||
|
reader = MessageReader(dispatcher)
|
||||||
|
return reader, dispatched
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CONTACT_DELETED handler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_contact_deleted_dispatches_event():
|
||||||
|
"""A 33-byte CONTACT_DELETED frame dispatches EventType.CONTACT_DELETED."""
|
||||||
|
reader, dispatched = _make_reader()
|
||||||
|
pubkey = bytes(range(32))
|
||||||
|
frame = bytes([PacketType.CONTACT_DELETED.value]) + pubkey
|
||||||
|
assert len(frame) == 33
|
||||||
|
|
||||||
|
await reader.handle_rx(bytearray(frame))
|
||||||
|
|
||||||
|
assert len(dispatched) == 1
|
||||||
|
evt = dispatched[0]
|
||||||
|
assert evt.type == EventType.CONTACT_DELETED
|
||||||
|
assert evt.payload["pubkey"] == pubkey.hex()
|
||||||
|
assert evt.attributes["pubkey"] == pubkey.hex()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_contact_deleted_short_frame_ignored():
|
||||||
|
"""A CONTACT_DELETED frame shorter than 33 bytes is silently dropped."""
|
||||||
|
reader, dispatched = _make_reader()
|
||||||
|
# Only 10 bytes — too short
|
||||||
|
frame = bytes([PacketType.CONTACT_DELETED.value]) + b"\x00" * 9
|
||||||
|
|
||||||
|
await reader.handle_rx(bytearray(frame))
|
||||||
|
|
||||||
|
assert len(dispatched) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CONTACTS_FULL handler + enum entry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_contacts_full_enum_exists():
|
||||||
|
"""PacketType.CONTACTS_FULL == 0x90."""
|
||||||
|
assert PacketType.CONTACTS_FULL.value == 0x90
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_contacts_full_dispatches_event():
|
||||||
|
"""A 1-byte CONTACTS_FULL push dispatches EventType.CONTACTS_FULL."""
|
||||||
|
reader, dispatched = _make_reader()
|
||||||
|
frame = bytes([PacketType.CONTACTS_FULL.value])
|
||||||
|
|
||||||
|
await reader.handle_rx(bytearray(frame))
|
||||||
|
|
||||||
|
assert len(dispatched) == 1
|
||||||
|
evt = dispatched[0]
|
||||||
|
assert evt.type == EventType.CONTACTS_FULL
|
||||||
|
assert evt.payload == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TUNING_PARAMS handler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tuning_params_dispatches_event():
|
||||||
|
"""A 9-byte TUNING_PARAMS frame dispatches with rx_delay and airtime_factor."""
|
||||||
|
reader, dispatched = _make_reader()
|
||||||
|
rx_delay = 500
|
||||||
|
airtime_factor = 200
|
||||||
|
frame = (
|
||||||
|
bytes([PacketType.TUNING_PARAMS.value])
|
||||||
|
+ rx_delay.to_bytes(4, "little")
|
||||||
|
+ airtime_factor.to_bytes(4, "little")
|
||||||
|
)
|
||||||
|
assert len(frame) == 9
|
||||||
|
|
||||||
|
await reader.handle_rx(bytearray(frame))
|
||||||
|
|
||||||
|
assert len(dispatched) == 1
|
||||||
|
evt = dispatched[0]
|
||||||
|
assert evt.type == EventType.TUNING_PARAMS
|
||||||
|
assert evt.payload["rx_delay"] == 500
|
||||||
|
assert evt.payload["airtime_factor"] == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tuning_params_short_frame_dispatches_error():
|
||||||
|
"""A TUNING_PARAMS frame shorter than 9 bytes dispatches ERROR."""
|
||||||
|
reader, dispatched = _make_reader()
|
||||||
|
# Only 5 bytes — too short
|
||||||
|
frame = bytes([PacketType.TUNING_PARAMS.value]) + b"\x01\x00\x00\x00"
|
||||||
|
|
||||||
|
await reader.handle_rx(bytearray(frame))
|
||||||
|
|
||||||
|
assert len(dispatched) == 1
|
||||||
|
evt = dispatched[0]
|
||||||
|
assert evt.type == EventType.ERROR
|
||||||
|
assert evt.payload["reason"] == "invalid_frame_length"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# send_trace() one-byte pad
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_trace_empty_path_pads_to_11_bytes():
|
||||||
|
"""send_trace() with no path produces an 11-byte packet (not 10)."""
|
||||||
|
from meshcore.commands.messaging import MessagingCommands
|
||||||
|
|
||||||
|
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
await cmd.send_trace(auth_code=0, tag=1, flags=0, path=None)
|
||||||
|
|
||||||
|
assert captured_data is not None
|
||||||
|
# cmd(1) + tag(4) + auth(4) + flags(1) + pad(1) = 11
|
||||||
|
assert len(captured_data) == 11
|
||||||
|
assert captured_data[-1] == 0x00 # The pad byte
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_trace_with_path_no_padding():
|
||||||
|
"""send_trace() with a non-empty path does NOT add padding."""
|
||||||
|
from meshcore.commands.messaging import MessagingCommands
|
||||||
|
|
||||||
|
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
# 2-byte path hash (flags=1 means hash_len=2)
|
||||||
|
await cmd.send_trace(auth_code=0, tag=1, flags=1, path=b"\xAA\xBB")
|
||||||
|
|
||||||
|
assert captured_data is not None
|
||||||
|
# cmd(1) + tag(4) + auth(4) + flags(1) + path(2) = 12 — no pad needed
|
||||||
|
assert len(captured_data) == 12
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command wrapper: send_raw_data
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_raw_data_wrapper():
|
||||||
|
"""send_raw_data sends CMD 0x19 + payload."""
|
||||||
|
from meshcore.commands.messaging import MessagingCommands
|
||||||
|
|
||||||
|
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
await cmd.send_raw_data(b"\xDE\xAD")
|
||||||
|
|
||||||
|
assert captured_data is not None
|
||||||
|
assert captured_data[0] == 0x19 # CMD_SEND_RAW_DATA
|
||||||
|
assert captured_data[1:] == b"\xDE\xAD"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command wrapper: has_connection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_has_connection_wrapper():
|
||||||
|
"""has_connection sends CMD 0x1c."""
|
||||||
|
from meshcore.commands.device import DeviceCommands
|
||||||
|
|
||||||
|
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.OK, {"value": 1})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
await cmd.has_connection()
|
||||||
|
|
||||||
|
assert captured_data is not None
|
||||||
|
assert captured_data == b"\x1c"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command wrapper: get_tuning
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tuning_wrapper():
|
||||||
|
"""get_tuning sends CMD 0x2b (GET_TUNING_PARAMS = 43)."""
|
||||||
|
from meshcore.commands.device import DeviceCommands
|
||||||
|
|
||||||
|
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.TUNING_PARAMS, {"rx_delay": 500, "airtime_factor": 200})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
result = await cmd.get_tuning()
|
||||||
|
|
||||||
|
assert captured_data == b"\x2b"
|
||||||
|
assert result.type == EventType.TUNING_PARAMS
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command wrapper: get_contact_by_key
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_contact_by_key_wrapper():
|
||||||
|
"""get_contact_by_key sends CMD 0x1e + 32-byte pubkey."""
|
||||||
|
from meshcore.commands.contact import ContactCommands
|
||||||
|
|
||||||
|
cmd = ContactCommands.__new__(ContactCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.NEXT_CONTACT, {"public_key": "ab" * 32})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
pubkey = bytes(range(32))
|
||||||
|
await cmd.get_contact_by_key(pubkey)
|
||||||
|
|
||||||
|
assert captured_data is not None
|
||||||
|
assert captured_data[0] == 0x1E # CMD_GET_CONTACT_BY_KEY
|
||||||
|
assert captured_data[1:] == pubkey
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Command wrapper: factory_reset (two-step)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_reset_two_step():
|
||||||
|
"""factory_reset requires a token from request_factory_reset."""
|
||||||
|
from meshcore.commands.device import DeviceCommands
|
||||||
|
|
||||||
|
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.OK, {})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
# Step 1: request token
|
||||||
|
token = await cmd.request_factory_reset()
|
||||||
|
assert isinstance(token, str)
|
||||||
|
assert len(token) == 16 # hex-encoded 8 bytes
|
||||||
|
|
||||||
|
# Step 2: confirm with wrong token fails
|
||||||
|
with pytest.raises(ValueError, match="Invalid or expired"):
|
||||||
|
await cmd.confirm_factory_reset("wrong_token")
|
||||||
|
|
||||||
|
# Step 2: confirm with correct token succeeds
|
||||||
|
await cmd.confirm_factory_reset(token)
|
||||||
|
assert captured_data == b"\x33" # CMD_FACTORY_RESET
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_reset_without_request_fails():
|
||||||
|
"""confirm_factory_reset without request_factory_reset raises ValueError."""
|
||||||
|
from meshcore.commands.device import DeviceCommands
|
||||||
|
|
||||||
|
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Invalid or expired"):
|
||||||
|
await cmd.confirm_factory_reset("any_token")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET_STATS enum entry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_get_stats_enum_exists():
|
||||||
|
"""CommandType.GET_STATS == 56."""
|
||||||
|
assert CommandType.GET_STATS.value == 56
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_stats_core_uses_enum():
|
||||||
|
"""get_stats_core sends CommandType.GET_STATS.value (0x38) + 0x00."""
|
||||||
|
from meshcore.commands.device import DeviceCommands
|
||||||
|
|
||||||
|
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||||
|
|
||||||
|
captured_data = None
|
||||||
|
|
||||||
|
async def mock_send(data, expected_events, timeout=None):
|
||||||
|
nonlocal captured_data
|
||||||
|
captured_data = bytes(data)
|
||||||
|
return Event(EventType.STATS_CORE, {})
|
||||||
|
|
||||||
|
cmd.send = mock_send
|
||||||
|
|
||||||
|
await cmd.get_stats_core()
|
||||||
|
|
||||||
|
assert captured_data is not None
|
||||||
|
assert captured_data[0] == CommandType.GET_STATS.value # 0x38 = 56
|
||||||
|
assert captured_data[1] == 0x00
|
||||||
238
tests/unit/test_transport_symmetry.py
Normal file
238
tests/unit/test_transport_symmetry.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""
|
||||||
|
Verification tests for transport symmetry fixes.
|
||||||
|
|
||||||
|
Covers: send symmetry across transports, serial disconnect callback on
|
||||||
|
transport-lost, serial connect timeout, oversize-frame return, BLE
|
||||||
|
disconnect-callback re-registration, BLE pairing failure re-raise,
|
||||||
|
TCP counter per frame not per segment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from meshcore.tcp_cx import TCPConnection
|
||||||
|
from meshcore.serial_cx import SerialConnection
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RecordingReader:
|
||||||
|
"""Minimal reader mock that records dispatched frames."""
|
||||||
|
def __init__(self):
|
||||||
|
self.frames = []
|
||||||
|
|
||||||
|
async def handle_rx(self, data):
|
||||||
|
self.frames.append(bytes(data))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TCP send() wraps transport.write in try/except
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_send_write_error_fires_disconnect():
|
||||||
|
"""TCP: OSError during transport.write fires _disconnect_callback."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(cb)
|
||||||
|
|
||||||
|
mock_transport = MagicMock()
|
||||||
|
mock_transport.write.side_effect = OSError("Broken pipe")
|
||||||
|
cx.transport = mock_transport
|
||||||
|
cx._send_count = 0
|
||||||
|
cx._receive_count = 0
|
||||||
|
|
||||||
|
await cx.send(b"\x01\x02\x03")
|
||||||
|
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert "tcp_write_failed" in reason
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serial send() fires disconnect on transport-lost and write error
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_send_no_transport_fires_disconnect():
|
||||||
|
"""Serial: send() on None transport fires _disconnect_callback ."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(cb)
|
||||||
|
cx.transport = None
|
||||||
|
|
||||||
|
await cx.send(b"\x01")
|
||||||
|
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert reason == "serial_transport_lost"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_send_write_error_fires_disconnect():
|
||||||
|
"""Serial: OSError during transport.write fires _disconnect_callback."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(cb)
|
||||||
|
|
||||||
|
mock_transport = MagicMock()
|
||||||
|
mock_transport.write.side_effect = OSError("Device not configured")
|
||||||
|
cx.transport = mock_transport
|
||||||
|
|
||||||
|
await cx.send(b"\x01")
|
||||||
|
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert "serial_write_failed" in reason
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BLE send() fires disconnect on transport-lost and write error
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ble_send_no_client_fires_disconnect():
|
||||||
|
"""BLE: send() with no client fires _disconnect_callback."""
|
||||||
|
# Can't import BLEConnection directly if bleak isn't installed,
|
||||||
|
# so test via dynamic import with a guard.
|
||||||
|
try:
|
||||||
|
from meshcore.ble_cx import BLEConnection
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("bleak not installed")
|
||||||
|
|
||||||
|
# BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it
|
||||||
|
with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \
|
||||||
|
patch("meshcore.ble_cx.BleakClient", MagicMock()):
|
||||||
|
cx = BLEConnection.__new__(BLEConnection)
|
||||||
|
cx.client = None
|
||||||
|
cx._user_provided_client = None
|
||||||
|
cx._user_provided_address = None
|
||||||
|
cx._user_provided_device = None
|
||||||
|
cx.address = None
|
||||||
|
cx.device = None
|
||||||
|
cx.pin = None
|
||||||
|
cx.rx_char = None
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx._disconnect_callback = cb
|
||||||
|
|
||||||
|
result = await cx.send(b"\x01")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert reason == "ble_transport_lost"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serial connect() times out if connection_made never fires
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_connect_timeout():
|
||||||
|
"""Serial: connect() raises TimeoutError if connection_made never fires."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
|
||||||
|
# Mock create_serial_connection to do nothing (never fires connection_made)
|
||||||
|
async def mock_create(*args, **kwargs):
|
||||||
|
return (MagicMock(), MagicMock())
|
||||||
|
|
||||||
|
with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection",
|
||||||
|
side_effect=mock_create):
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await cx.connect(timeout=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Oversize frame resets state and returns without dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_oversize_frame_empty_data_returns():
|
||||||
|
"""TCP: oversize header with no trailing data returns without dispatch."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
|
||||||
|
# Build a frame header with size > 300 and no payload data after header
|
||||||
|
# Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4)
|
||||||
|
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||||
|
cx.handle_rx(header)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# No frames should be dispatched, and state should be reset
|
||||||
|
assert reader.frames == []
|
||||||
|
assert cx.header == b""
|
||||||
|
assert cx.inframe == b""
|
||||||
|
assert cx.frame_expected_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_oversize_frame_empty_data_returns():
|
||||||
|
"""Serial: oversize header with no trailing data returns without dispatch."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
|
||||||
|
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||||
|
cx.handle_rx(header)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert reader.frames == []
|
||||||
|
assert cx.header == b""
|
||||||
|
assert cx.inframe == b""
|
||||||
|
assert cx.frame_expected_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TCP receive counter increments per MeshCore frame, not per TCP segment
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_receive_count_per_frame_not_per_segment():
|
||||||
|
"""TCP: _receive_count increments per completed frame, not per data_received call."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
cx._receive_count = 0
|
||||||
|
|
||||||
|
# Build a 4-byte payload frame
|
||||||
|
payload = b"\xAA\xBB\xCC\xDD"
|
||||||
|
frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload
|
||||||
|
|
||||||
|
# Split the frame into 3 TCP segments (simulating fragmentation)
|
||||||
|
protocol = TCPConnection.MCClientProtocol(cx)
|
||||||
|
protocol.data_received(frame[:2]) # partial header
|
||||||
|
protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload
|
||||||
|
protocol.data_received(frame[5:]) # remaining payload
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# 3 data_received calls but only 1 completed frame
|
||||||
|
assert cx._receive_count == 1
|
||||||
|
assert reader.frames == [payload]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_multiple_frames_count_correctly():
|
||||||
|
"""TCP: two complete frames in separate segments → _receive_count == 2."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
cx._receive_count = 0
|
||||||
|
|
||||||
|
payload1 = b"\x01\x02"
|
||||||
|
frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1
|
||||||
|
payload2 = b"\x03\x04\x05"
|
||||||
|
frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2
|
||||||
|
|
||||||
|
protocol = TCPConnection.MCClientProtocol(cx)
|
||||||
|
protocol.data_received(frame1)
|
||||||
|
protocol.data_received(frame2)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert cx._receive_count == 2
|
||||||
|
assert reader.frames == [payload1, payload2]
|
||||||
Reference in New Issue
Block a user