mirror of
https://github.com/meshcore-dev/meshcore_py.git
synced 2026-06-11 03:56:16 +00:00
Merge branch 'main' into fix/reader-parser-crash-safety
This commit is contained in:
@@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase):
|
||||
|
||||
@patch("meshcore.ble_cx.BleakClient")
|
||||
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
|
||||
"""Test BLE connection with PIN when pairing fails but connection continues"""
|
||||
"""Test BLE connection with PIN when pairing fails — re-raises (F17)."""
|
||||
# Arrange
|
||||
mock_client_instance = self._get_mock_bleak_client()
|
||||
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
|
||||
@@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase):
|
||||
pin = "123456"
|
||||
ble_conn = BLEConnection(address=address, pin=pin)
|
||||
|
||||
# Act
|
||||
result = asyncio.run(ble_conn.connect())
|
||||
|
||||
# Assert
|
||||
# Act & Assert — pairing failure now re-raises instead of being
|
||||
# swallowed, because a half-usable transport is worse than a clean
|
||||
# failure (forensics finding F17).
|
||||
with self.assertRaises(Exception) as ctx:
|
||||
asyncio.run(ble_conn.connect())
|
||||
self.assertIn("Pairing failed", str(ctx.exception))
|
||||
mock_client_instance.connect.assert_called_once()
|
||||
mock_client_instance.pair.assert_called_once()
|
||||
mock_client_instance.start_notify.assert_called_once_with(
|
||||
UART_TX_CHAR_UUID, ble_conn.handle_rx
|
||||
)
|
||||
# Connection should still succeed even if pairing fails
|
||||
self.assertEqual(result, address)
|
||||
# disconnect should be called to clean up the failed connection
|
||||
mock_client_instance.disconnect.assert_called_once()
|
||||
|
||||
@patch("meshcore.ble_cx.BleakClient")
|
||||
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):
|
||||
|
||||
235
tests/unit/test_asyncio_lifecycle.py
Normal file
235
tests/unit/test_asyncio_lifecycle.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Verification tests for asyncio lifecycle fixes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from meshcore.events import Event, EventDispatcher, EventType
|
||||
from meshcore.tcp_cx import TCPConnection
|
||||
from meshcore.serial_cx import SerialConnection
|
||||
from meshcore.commands.base import CommandHandlerBase
|
||||
|
||||
|
||||
class TestBackgroundTaskTracking(unittest.TestCase):
|
||||
"""Fire-and-forget create_task calls must be tracked to prevent GC."""
|
||||
|
||||
def test_tcp_spawn_background_retains_task(self):
|
||||
"""TCP _spawn_background adds the task to _background_tasks."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = cx._spawn_background(dummy())
|
||||
assert task in cx._background_tasks
|
||||
await completed.wait()
|
||||
# After completion, done_callback should have discarded it
|
||||
await asyncio.sleep(0) # let done callback fire
|
||||
assert task not in cx._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_serial_spawn_background_retains_task(self):
|
||||
"""Serial _spawn_background adds the task to _background_tasks."""
|
||||
async def _run():
|
||||
with patch("meshcore.serial_cx.asyncio.Event") as mock_event:
|
||||
mock_event.return_value = MagicMock()
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = cx._spawn_background(dummy())
|
||||
assert task in cx._background_tasks
|
||||
await completed.wait()
|
||||
await asyncio.sleep(0)
|
||||
assert task not in cx._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_event_dispatcher_spawn_background_retains_task(self):
|
||||
"""EventDispatcher _spawn_background adds task to _background_tasks."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
completed = asyncio.Event()
|
||||
|
||||
async def dummy():
|
||||
completed.set()
|
||||
|
||||
task = dispatcher._spawn_background(dummy())
|
||||
assert task in dispatcher._background_tasks
|
||||
await completed.wait()
|
||||
await asyncio.sleep(0)
|
||||
assert task not in dispatcher._background_tasks
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_tcp_handle_rx_uses_tracked_task(self):
|
||||
"""TCP handle_rx dispatches reader.handle_rx via _spawn_background."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
reader = AsyncMock()
|
||||
reader.handle_rx = AsyncMock()
|
||||
cx.set_reader(reader)
|
||||
|
||||
# Build a minimal valid frame: 0x3e + 2-byte LE size + payload
|
||||
payload = b"\x01\x02\x03"
|
||||
size = len(payload).to_bytes(2, "little")
|
||||
frame = b"\x3e" + size + payload
|
||||
|
||||
cx.handle_rx(frame)
|
||||
# Task should be tracked
|
||||
assert len(cx._background_tasks) == 1
|
||||
# Let task complete
|
||||
await asyncio.sleep(0.05)
|
||||
reader.handle_rx.assert_awaited_once_with(payload)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_tcp_connection_lost_uses_tracked_task(self):
|
||||
"""TCP connection_lost dispatches disconnect callback via _spawn_background."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
callback = AsyncMock()
|
||||
cx.set_disconnect_callback(callback)
|
||||
|
||||
protocol = cx.MCClientProtocol(cx)
|
||||
protocol.connection_lost(None)
|
||||
|
||||
assert len(cx._background_tasks) == 1
|
||||
await asyncio.sleep(0.05)
|
||||
callback.assert_awaited_once_with("tcp_disconnect")
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_gc_does_not_cancel_tracked_tasks(self):
|
||||
"""Tracked tasks survive GC pressure (the whole point of tracking)."""
|
||||
async def _run():
|
||||
cx = TCPConnection("127.0.0.1", 5555)
|
||||
result = []
|
||||
|
||||
async def slow_task():
|
||||
await asyncio.sleep(0.05)
|
||||
result.append("done")
|
||||
|
||||
cx._spawn_background(slow_task())
|
||||
# Force GC — untracked tasks could be collected here
|
||||
gc.collect()
|
||||
await asyncio.sleep(0.1)
|
||||
assert result == ["done"]
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestTaskDoneCorrectness(unittest.TestCase):
|
||||
"""EventDispatcher.stop() must wait for in-flight async callbacks."""
|
||||
|
||||
def test_stop_waits_for_async_callbacks(self):
|
||||
"""stop() should not return until async callbacks have completed."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
|
||||
callback_completed = False
|
||||
|
||||
async def slow_callback(event):
|
||||
nonlocal callback_completed
|
||||
await asyncio.sleep(0.1)
|
||||
callback_completed = True
|
||||
|
||||
dispatcher.subscribe(EventType.OK, slow_callback)
|
||||
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||
|
||||
# Give the dispatch loop a moment to pick up the event
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
# stop() should wait for slow_callback to finish
|
||||
await dispatcher.stop()
|
||||
assert callback_completed, "stop() returned before async callback completed"
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestDeferredPrimitiveConstruction(unittest.TestCase):
|
||||
"""Queue and Lock must not bind to import-time loop."""
|
||||
|
||||
def test_event_dispatcher_queue_is_none_before_start(self):
|
||||
"""EventDispatcher.queue should be None until start() is called."""
|
||||
dispatcher = EventDispatcher()
|
||||
assert dispatcher.queue is None
|
||||
|
||||
def test_event_dispatcher_queue_created_on_start(self):
|
||||
"""start() creates the queue."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
assert dispatcher.queue is None
|
||||
await dispatcher.start()
|
||||
assert dispatcher.queue is not None
|
||||
assert isinstance(dispatcher.queue, asyncio.Queue)
|
||||
await dispatcher.stop()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_event_dispatcher_dispatch_before_start_raises(self):
|
||||
"""dispatch() before start() should raise RuntimeError."""
|
||||
async def _run():
|
||||
dispatcher = EventDispatcher()
|
||||
with self.assertRaises(RuntimeError):
|
||||
await dispatcher.dispatch(Event(EventType.OK, {}))
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
def test_command_handler_lock_is_none_before_use(self):
|
||||
"""CommandHandlerBase lock should be None until first access."""
|
||||
handler = CommandHandlerBase()
|
||||
assert handler._CommandHandlerBase__mesh_request_lock is None
|
||||
|
||||
def test_command_handler_lock_created_on_access(self):
|
||||
"""Accessing _mesh_request_lock creates it lazily."""
|
||||
async def _run():
|
||||
handler = CommandHandlerBase()
|
||||
lock = handler._mesh_request_lock
|
||||
assert isinstance(lock, asyncio.Lock)
|
||||
# Second access returns same instance
|
||||
assert handler._mesh_request_lock is lock
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class TestGetRunningLoop(unittest.TestCase):
|
||||
"""get_event_loop() replaced with get_running_loop() in send()."""
|
||||
|
||||
def test_send_uses_get_running_loop(self):
|
||||
"""send() should call get_running_loop, not get_event_loop."""
|
||||
async def _run():
|
||||
handler = CommandHandlerBase()
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
handler.set_dispatcher(dispatcher)
|
||||
|
||||
mock_sender = AsyncMock()
|
||||
handler._sender_func = mock_sender
|
||||
|
||||
# Patch get_running_loop to verify it's called
|
||||
with patch("meshcore.commands.base.asyncio.get_running_loop", wraps=asyncio.get_running_loop) as mock_grl:
|
||||
# send with expected_events triggers the loop = asyncio.get_running_loop() path
|
||||
result = await handler.send(
|
||||
b"\x01",
|
||||
expected_events=[EventType.OK],
|
||||
timeout=0.05,
|
||||
)
|
||||
mock_grl.assert_called()
|
||||
|
||||
await dispatcher.stop()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -28,6 +28,10 @@ def mock_dispatcher():
|
||||
sub.unsubscribe = MagicMock()
|
||||
dispatcher._last_subscribe_handler = handler
|
||||
dispatcher._last_subscribe_event_type = event_type
|
||||
# Immediately resolve the future so send() doesn't block
|
||||
asyncio.get_event_loop().call_soon(
|
||||
handler, Event(event_type, {})
|
||||
)
|
||||
return sub
|
||||
|
||||
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
@@ -80,6 +84,13 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher
|
||||
|
||||
|
||||
async def test_send_timeout(command_handler, mock_connection, mock_dispatcher):
|
||||
# Override to NOT resolve events, so we can test the timeout path
|
||||
def non_resolving_subscribe(event_type, handler, attribute_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
return sub
|
||||
mock_dispatcher.subscribe = MagicMock(side_effect=non_resolving_subscribe)
|
||||
|
||||
result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1)
|
||||
assert result.type == EventType.ERROR
|
||||
assert result.payload == {"reason": "no_event_received"}
|
||||
|
||||
294
tests/unit/test_connection_manager.py
Normal file
294
tests/unit/test_connection_manager.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Tests for reconnect-path fixes."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from meshcore.connection_manager import ConnectionManager
|
||||
from meshcore.events import Event, EventDispatcher, EventType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeConnection:
|
||||
"""Minimal stub that satisfies ConnectionProtocol."""
|
||||
|
||||
def __init__(self, connect_results=None):
|
||||
"""
|
||||
Args:
|
||||
connect_results: iterator of return values for successive
|
||||
connect() calls. ``None`` means soft failure; a string
|
||||
means success; raising is also supported via sentinel.
|
||||
"""
|
||||
self._connect_results = list(connect_results or ["ok"])
|
||||
self._call_index = 0
|
||||
self.reader = None
|
||||
|
||||
async def connect(self):
|
||||
if self._call_index < len(self._connect_results):
|
||||
result = self._connect_results[self._call_index]
|
||||
self._call_index += 1
|
||||
else:
|
||||
result = self._connect_results[-1]
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def send(self, data):
|
||||
pass
|
||||
|
||||
def set_reader(self, reader):
|
||||
self.reader = reader
|
||||
|
||||
|
||||
class RaisingConnection(FakeConnection):
|
||||
"""Connection that raises on every connect() attempt."""
|
||||
|
||||
def __init__(self, exc=None):
|
||||
super().__init__()
|
||||
self._exc = exc or ConnectionError("boom")
|
||||
|
||||
async def connect(self):
|
||||
raise self._exc
|
||||
|
||||
|
||||
class _EventCollector:
|
||||
"""Subscribes to all events and records them."""
|
||||
|
||||
def __init__(self, dispatcher: EventDispatcher):
|
||||
self.events: list[Event] = []
|
||||
dispatcher.subscribe(None, self._on_event)
|
||||
|
||||
async def _on_event(self, event: Event):
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TCP connect() should return a plain value, not an asyncio.Future
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_connect_returns_plain_string():
|
||||
"""TCPConnection.connect() returns self.host (a plain string), not an
|
||||
asyncio.Future. We test indirectly via ConnectionManager — the
|
||||
CONNECTED event payload should contain a plain string, not a Future
|
||||
object."""
|
||||
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(conn, dispatcher)
|
||||
|
||||
result = await mgr.connect()
|
||||
|
||||
assert result == "10.0.0.1"
|
||||
# Give the dispatcher a moment to deliver the event
|
||||
await asyncio.sleep(0.05)
|
||||
connected_events = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||
assert len(connected_events) == 1
|
||||
payload = connected_events[0].payload
|
||||
assert payload["connection_info"] == "10.0.0.1"
|
||||
# The payload value must NOT be an asyncio.Future
|
||||
assert not isinstance(payload["connection_info"], asyncio.Future)
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconnect attempts must not compound (no tail-recursive create_task)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_loop_does_not_compound():
|
||||
"""_attempt_reconnect must use a single iterative loop. After
|
||||
max_reconnect_attempts failures, exactly that many connect() calls
|
||||
should have been made — no exponential fan-out from orphaned tasks."""
|
||||
# All attempts fail (return None)
|
||||
conn = FakeConnection(connect_results=[None, None, None, None])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=3,
|
||||
)
|
||||
mgr._is_connected = True # simulate a live connection
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
# Wait for the reconnect loop to exhaust all attempts
|
||||
# (3 attempts × 1s sleep each, but we can just await the task)
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
# Exactly 3 connect() calls should have been made
|
||||
assert conn._call_index == 3
|
||||
|
||||
# A DISCONNECTED event with max_attempts_exceeded should have fired
|
||||
await asyncio.sleep(0.05)
|
||||
disconnected = [e for e in collector.events if e.type == EventType.DISCONNECTED]
|
||||
assert len(disconnected) == 1
|
||||
assert disconnected[0].payload.get("max_attempts_exceeded") is True
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_cancels_reconnect_loop():
|
||||
"""disconnect() during an active reconnect loop must cancel the
|
||||
single task cleanly — no orphaned tasks left running."""
|
||||
# Simulate a connection that always fails (returns None), giving us
|
||||
# time to call disconnect() mid-loop.
|
||||
conn = FakeConnection(connect_results=[None, None, None, None, None])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher, auto_reconnect=True, max_reconnect_attempts=5,
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
|
||||
# Let the first attempt start (wait just past the 1s sleep)
|
||||
await asyncio.sleep(1.2)
|
||||
assert conn._call_index >= 1 # at least one attempt made
|
||||
|
||||
# Now disconnect — should cancel the loop
|
||||
await mgr.disconnect()
|
||||
|
||||
assert mgr._reconnect_task is None
|
||||
calls_at_cancel = conn._call_index
|
||||
|
||||
# Wait a bit and confirm no more attempts happened
|
||||
await asyncio.sleep(2)
|
||||
assert conn._call_index == calls_at_cancel
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reconnect_callback (send_appstart) is called after reconnect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_callback_called_after_reconnect():
|
||||
"""When ConnectionManager reconnects successfully, the
|
||||
reconnect_callback (e.g. send_appstart) must be invoked."""
|
||||
callback_called = []
|
||||
|
||||
async def fake_appstart():
|
||||
callback_called.append(True)
|
||||
|
||||
# First connect() fails (None), second succeeds
|
||||
conn = FakeConnection(connect_results=[None, "10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher,
|
||||
auto_reconnect=True,
|
||||
max_reconnect_attempts=3,
|
||||
reconnect_callback=fake_appstart,
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
assert len(callback_called) == 1
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_callback_failure_does_not_crash_loop():
|
||||
"""If the reconnect_callback raises, the reconnect still counts as
|
||||
successful (transport is up) — the callback failure is logged but
|
||||
does not crash the loop or leave the manager in a broken state."""
|
||||
async def failing_callback():
|
||||
raise RuntimeError("appstart failed")
|
||||
|
||||
# connect() succeeds on first attempt
|
||||
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher,
|
||||
auto_reconnect=True,
|
||||
max_reconnect_attempts=3,
|
||||
reconnect_callback=failing_callback,
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
# Despite callback failure, CONNECTED event should have fired
|
||||
await asyncio.sleep(0.05)
|
||||
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||
assert len(connected) == 1
|
||||
assert mgr._is_connected is True
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# connect() returning None is a soft failure (BLE scan miss)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_none_is_soft_failure():
|
||||
"""When connect() returns None (e.g. BLE scan found no device),
|
||||
ConnectionManager.connect() should NOT set _is_connected and should
|
||||
NOT emit a CONNECTED event."""
|
||||
conn = FakeConnection(connect_results=[None])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
collector = _EventCollector(dispatcher)
|
||||
mgr = ConnectionManager(conn, dispatcher)
|
||||
|
||||
result = await mgr.connect()
|
||||
|
||||
assert result is None
|
||||
assert mgr._is_connected is False
|
||||
await asyncio.sleep(0.05)
|
||||
connected = [e for e in collector.events if e.type == EventType.CONNECTED]
|
||||
assert len(connected) == 0
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_reconnect_callback_is_noop():
|
||||
"""When no reconnect_callback is provided (backwards compat for
|
||||
direct ConnectionManager users), reconnect should still work."""
|
||||
conn = FakeConnection(connect_results=["10.0.0.1"])
|
||||
dispatcher = EventDispatcher()
|
||||
await dispatcher.start()
|
||||
try:
|
||||
mgr = ConnectionManager(
|
||||
conn, dispatcher,
|
||||
auto_reconnect=True,
|
||||
max_reconnect_attempts=3,
|
||||
# No reconnect_callback — default None
|
||||
)
|
||||
mgr._is_connected = True
|
||||
|
||||
await mgr.handle_disconnect("test_disconnect")
|
||||
if mgr._reconnect_task:
|
||||
await mgr._reconnect_task
|
||||
|
||||
assert mgr._is_connected is True
|
||||
finally:
|
||||
await dispatcher.stop()
|
||||
236
tests/unit/test_error_handling.py
Normal file
236
tests/unit/test_error_handling.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Verification tests for error response handling fixes.
|
||||
|
||||
The tests confirm that error responses are surfaced cleanly instead
|
||||
of causing KeyError, TypeError, NameError, or silent fallthrough.
|
||||
"""
|
||||
import asyncio
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from meshcore.commands import CommandHandler
|
||||
from meshcore.events import EventType, Event, Subscription
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes
|
||||
|
||||
|
||||
# ── Fixtures ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection():
|
||||
connection = MagicMock()
|
||||
connection.send = AsyncMock()
|
||||
return connection
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dispatcher():
|
||||
dispatcher = MagicMock()
|
||||
dispatcher.wait_for_event = AsyncMock()
|
||||
dispatcher.dispatch = AsyncMock()
|
||||
|
||||
def fake_subscribe(event_type, handler, attribute_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
dispatcher._last_subscribe_handler = handler
|
||||
dispatcher._last_subscribe_event_type = event_type
|
||||
return sub
|
||||
|
||||
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
return dispatcher
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def command_handler(mock_connection, mock_dispatcher):
|
||||
handler = CommandHandler()
|
||||
|
||||
async def sender(data):
|
||||
await mock_connection.send(data)
|
||||
|
||||
handler._sender_func = sender
|
||||
handler.dispatcher = mock_dispatcher
|
||||
return handler
|
||||
|
||||
|
||||
def setup_error_response(mock_dispatcher):
|
||||
"""Configure dispatcher to return an ERROR event for any subscribe."""
|
||||
def fake_subscribe(evt_type, handler, attr_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
# Always fire ERROR regardless of which event type was subscribed
|
||||
if evt_type == EventType.ERROR:
|
||||
asyncio.get_event_loop().call_soon(
|
||||
handler, Event(EventType.ERROR, {"reason": "test_error"})
|
||||
)
|
||||
return sub
|
||||
|
||||
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
|
||||
|
||||
def setup_event_response(mock_dispatcher, event_type, payload):
|
||||
"""Configure dispatcher to return a specific event."""
|
||||
def fake_subscribe(evt_type, handler, attr_filters=None):
|
||||
sub = MagicMock(spec=Subscription)
|
||||
sub.unsubscribe = MagicMock()
|
||||
if evt_type == event_type:
|
||||
asyncio.get_event_loop().call_soon(
|
||||
handler, Event(event_type, payload)
|
||||
)
|
||||
return sub
|
||||
|
||||
mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
|
||||
|
||||
|
||||
# ── Event.is_error() helper ──────────────────────────────────
|
||||
|
||||
async def test_event_is_error_true():
|
||||
"""is_error() returns True for ERROR events."""
|
||||
event = Event(EventType.ERROR, {"reason": "test"})
|
||||
assert event.is_error() is True
|
||||
|
||||
|
||||
async def test_event_is_error_false():
|
||||
"""is_error() returns False for non-ERROR events."""
|
||||
event = Event(EventType.OK, {})
|
||||
assert event.is_error() is False
|
||||
event2 = Event(EventType.SELF_INFO, {"name": "test"})
|
||||
assert event2.is_error() is False
|
||||
|
||||
|
||||
# ── send_msg_with_retry continues on ERROR ──────────────
|
||||
|
||||
async def test_send_msg_with_retry_error_no_keyerror(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""send_msg_with_retry returns None (exhausted retries) on
|
||||
persistent ERROR instead of raising KeyError on missing 'expected_ack'."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
# Provide a mock contact so the path logic doesn't interfere
|
||||
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
|
||||
|
||||
# max_attempts=2 so it retries once then gives up
|
||||
result = await command_handler.send_msg_with_retry(
|
||||
VALID_PUBKEY_HEX, "hello", max_attempts=2, timeout=0.1
|
||||
)
|
||||
|
||||
# Should return None (no ACK received) rather than raising KeyError
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── send_appstart includes ERROR in expected events ──────────
|
||||
|
||||
async def test_send_appstart_returns_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""send_appstart returns ERROR event instead of hanging on timeout."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.send_appstart()
|
||||
|
||||
assert result.type == EventType.ERROR
|
||||
assert result.is_error() is True
|
||||
assert result.payload["reason"] == "test_error"
|
||||
|
||||
|
||||
# ── device setters return ERROR from send_appstart ───────────
|
||||
|
||||
async def test_set_telemetry_mode_base_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_telemetry_mode_base returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_telemetry_mode_base(1)
|
||||
|
||||
assert result.is_error()
|
||||
assert result.payload["reason"] == "test_error"
|
||||
|
||||
|
||||
async def test_set_telemetry_mode_loc_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_telemetry_mode_loc returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_telemetry_mode_loc(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_telemetry_mode_env_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_telemetry_mode_env returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_telemetry_mode_env(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_manual_add_contacts_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_manual_add_contacts returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_manual_add_contacts(True)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_advert_loc_policy_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_advert_loc_policy returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_advert_loc_policy(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
async def test_set_multi_acks_error(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""set_multi_acks returns ERROR instead of KeyError."""
|
||||
setup_error_response(mock_dispatcher)
|
||||
|
||||
result = await command_handler.set_multi_acks(1)
|
||||
|
||||
assert result.is_error()
|
||||
|
||||
|
||||
# ── send_anon_req returns ERROR on contact not found ─────────
|
||||
|
||||
async def test_send_anon_req_contact_not_found(
|
||||
command_handler, mock_dispatcher
|
||||
):
|
||||
"""send_anon_req returns ERROR event when contact prefix not found,
|
||||
instead of raising TypeError on NoneType subscript."""
|
||||
command_handler._get_contact_by_prefix = MagicMock(return_value=None)
|
||||
|
||||
result = await command_handler.send_anon_req(
|
||||
VALID_PUBKEY_HEX, MagicMock(value=1)
|
||||
)
|
||||
|
||||
assert result.is_error()
|
||||
assert result.payload["reason"] == "contact_not_found"
|
||||
|
||||
|
||||
# ── send_trace handles unknown path_hash_len without NameError ──
|
||||
|
||||
async def test_send_trace_unknown_path_hash_len(
|
||||
command_handler, mock_connection, mock_dispatcher
|
||||
):
|
||||
"""send_trace with a path whose segments don't match any known
|
||||
path_hash_len returns ERROR cleanly instead of NameError on 'e'."""
|
||||
# 5-char hex segments → path_hash_len = 2.5 → doesn't match 1,2,4,8
|
||||
result = await command_handler.send_trace(
|
||||
auth_code=0, tag=1, flags=None, path="abcde"
|
||||
)
|
||||
|
||||
assert result.is_error()
|
||||
assert result.payload["reason"] == "invalid_path_format"
|
||||
364
tests/unit/test_protocol_surface_gaps.py
Normal file
364
tests/unit/test_protocol_surface_gaps.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""Verification tests for protocol surface gaps.
|
||||
|
||||
Each test constructs a mock firmware frame and verifies the SDK dispatches
|
||||
the correct EventType with the expected payload fields.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from meshcore.events import Event, EventType, EventDispatcher
|
||||
from meshcore.reader import MessageReader
|
||||
from meshcore.packets import PacketType, CommandType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_reader():
|
||||
"""Create a MessageReader with a mock dispatcher that records dispatched events."""
|
||||
dispatcher = MagicMock(spec=EventDispatcher)
|
||||
dispatched = []
|
||||
|
||||
async def _capture(event):
|
||||
dispatched.append(event)
|
||||
|
||||
dispatcher.dispatch = AsyncMock(side_effect=_capture)
|
||||
reader = MessageReader(dispatcher)
|
||||
return reader, dispatched
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CONTACT_DELETED handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contact_deleted_dispatches_event():
|
||||
"""A 33-byte CONTACT_DELETED frame dispatches EventType.CONTACT_DELETED."""
|
||||
reader, dispatched = _make_reader()
|
||||
pubkey = bytes(range(32))
|
||||
frame = bytes([PacketType.CONTACT_DELETED.value]) + pubkey
|
||||
assert len(frame) == 33
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.CONTACT_DELETED
|
||||
assert evt.payload["pubkey"] == pubkey.hex()
|
||||
assert evt.attributes["pubkey"] == pubkey.hex()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contact_deleted_short_frame_ignored():
|
||||
"""A CONTACT_DELETED frame shorter than 33 bytes is silently dropped."""
|
||||
reader, dispatched = _make_reader()
|
||||
# Only 10 bytes — too short
|
||||
frame = bytes([PacketType.CONTACT_DELETED.value]) + b"\x00" * 9
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CONTACTS_FULL handler + enum entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_contacts_full_enum_exists():
|
||||
"""PacketType.CONTACTS_FULL == 0x90."""
|
||||
assert PacketType.CONTACTS_FULL.value == 0x90
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contacts_full_dispatches_event():
|
||||
"""A 1-byte CONTACTS_FULL push dispatches EventType.CONTACTS_FULL."""
|
||||
reader, dispatched = _make_reader()
|
||||
frame = bytes([PacketType.CONTACTS_FULL.value])
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.CONTACTS_FULL
|
||||
assert evt.payload == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TUNING_PARAMS handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tuning_params_dispatches_event():
|
||||
"""A 9-byte TUNING_PARAMS frame dispatches with rx_delay and airtime_factor."""
|
||||
reader, dispatched = _make_reader()
|
||||
rx_delay = 500
|
||||
airtime_factor = 200
|
||||
frame = (
|
||||
bytes([PacketType.TUNING_PARAMS.value])
|
||||
+ rx_delay.to_bytes(4, "little")
|
||||
+ airtime_factor.to_bytes(4, "little")
|
||||
)
|
||||
assert len(frame) == 9
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.TUNING_PARAMS
|
||||
assert evt.payload["rx_delay"] == 500
|
||||
assert evt.payload["airtime_factor"] == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tuning_params_short_frame_dispatches_error():
|
||||
"""A TUNING_PARAMS frame shorter than 9 bytes dispatches ERROR."""
|
||||
reader, dispatched = _make_reader()
|
||||
# Only 5 bytes — too short
|
||||
frame = bytes([PacketType.TUNING_PARAMS.value]) + b"\x01\x00\x00\x00"
|
||||
|
||||
await reader.handle_rx(bytearray(frame))
|
||||
|
||||
assert len(dispatched) == 1
|
||||
evt = dispatched[0]
|
||||
assert evt.type == EventType.ERROR
|
||||
assert evt.payload["reason"] == "invalid_frame_length"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_trace() one-byte pad
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_trace_empty_path_pads_to_11_bytes():
|
||||
"""send_trace() with no path produces an 11-byte packet (not 10)."""
|
||||
from meshcore.commands.messaging import MessagingCommands
|
||||
|
||||
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.send_trace(auth_code=0, tag=1, flags=0, path=None)
|
||||
|
||||
assert captured_data is not None
|
||||
# cmd(1) + tag(4) + auth(4) + flags(1) + pad(1) = 11
|
||||
assert len(captured_data) == 11
|
||||
assert captured_data[-1] == 0x00 # The pad byte
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_trace_with_path_no_padding():
|
||||
"""send_trace() with a non-empty path does NOT add padding."""
|
||||
from meshcore.commands.messaging import MessagingCommands
|
||||
|
||||
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
# 2-byte path hash (flags=1 means hash_len=2)
|
||||
await cmd.send_trace(auth_code=0, tag=1, flags=1, path=b"\xAA\xBB")
|
||||
|
||||
assert captured_data is not None
|
||||
# cmd(1) + tag(4) + auth(4) + flags(1) + path(2) = 12 — no pad needed
|
||||
assert len(captured_data) == 12
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: send_raw_data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raw_data_wrapper():
|
||||
"""send_raw_data sends CMD 0x19 + payload."""
|
||||
from meshcore.commands.messaging import MessagingCommands
|
||||
|
||||
cmd = MessagingCommands.__new__(MessagingCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.MSG_SENT, {"type": 0, "expected_ack": b"\x00" * 4, "suggested_timeout": 1000})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.send_raw_data(b"\xDE\xAD")
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data[0] == 0x19 # CMD_SEND_RAW_DATA
|
||||
assert captured_data[1:] == b"\xDE\xAD"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: has_connection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_connection_wrapper():
|
||||
"""has_connection sends CMD 0x1c."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.OK, {"value": 1})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.has_connection()
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data == b"\x1c"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: get_tuning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tuning_wrapper():
|
||||
"""get_tuning sends CMD 0x2b (GET_TUNING_PARAMS = 43)."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.TUNING_PARAMS, {"rx_delay": 500, "airtime_factor": 200})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
result = await cmd.get_tuning()
|
||||
|
||||
assert captured_data == b"\x2b"
|
||||
assert result.type == EventType.TUNING_PARAMS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: get_contact_by_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_contact_by_key_wrapper():
|
||||
"""get_contact_by_key sends CMD 0x1e + 32-byte pubkey."""
|
||||
from meshcore.commands.contact import ContactCommands
|
||||
|
||||
cmd = ContactCommands.__new__(ContactCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.NEXT_CONTACT, {"public_key": "ab" * 32})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
pubkey = bytes(range(32))
|
||||
await cmd.get_contact_by_key(pubkey)
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data[0] == 0x1E # CMD_GET_CONTACT_BY_KEY
|
||||
assert captured_data[1:] == pubkey
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command wrapper: factory_reset (two-step)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_reset_two_step():
|
||||
"""factory_reset requires a token from request_factory_reset."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.OK, {})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
# Step 1: request token
|
||||
token = await cmd.request_factory_reset()
|
||||
assert isinstance(token, str)
|
||||
assert len(token) == 16 # hex-encoded 8 bytes
|
||||
|
||||
# Step 2: confirm with wrong token fails
|
||||
with pytest.raises(ValueError, match="Invalid or expired"):
|
||||
await cmd.confirm_factory_reset("wrong_token")
|
||||
|
||||
# Step 2: confirm with correct token succeeds
|
||||
await cmd.confirm_factory_reset(token)
|
||||
assert captured_data == b"\x33" # CMD_FACTORY_RESET
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_reset_without_request_fails():
|
||||
"""confirm_factory_reset without request_factory_reset raises ValueError."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid or expired"):
|
||||
await cmd.confirm_factory_reset("any_token")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET_STATS enum entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_stats_enum_exists():
|
||||
"""CommandType.GET_STATS == 56."""
|
||||
assert CommandType.GET_STATS.value == 56
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_core_uses_enum():
|
||||
"""get_stats_core sends CommandType.GET_STATS.value (0x38) + 0x00."""
|
||||
from meshcore.commands.device import DeviceCommands
|
||||
|
||||
cmd = DeviceCommands.__new__(DeviceCommands)
|
||||
|
||||
captured_data = None
|
||||
|
||||
async def mock_send(data, expected_events, timeout=None):
|
||||
nonlocal captured_data
|
||||
captured_data = bytes(data)
|
||||
return Event(EventType.STATS_CORE, {})
|
||||
|
||||
cmd.send = mock_send
|
||||
|
||||
await cmd.get_stats_core()
|
||||
|
||||
assert captured_data is not None
|
||||
assert captured_data[0] == CommandType.GET_STATS.value # 0x38 = 56
|
||||
assert captured_data[1] == 0x00
|
||||
238
tests/unit/test_transport_symmetry.py
Normal file
238
tests/unit/test_transport_symmetry.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Verification tests for transport symmetry fixes.
|
||||
|
||||
Covers: send symmetry across transports, serial disconnect callback on
|
||||
transport-lost, serial connect timeout, oversize-frame return, BLE
|
||||
disconnect-callback re-registration, BLE pairing failure re-raise,
|
||||
TCP counter per frame not per segment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from meshcore.tcp_cx import TCPConnection
|
||||
from meshcore.serial_cx import SerialConnection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RecordingReader:
|
||||
"""Minimal reader mock that records dispatched frames."""
|
||||
def __init__(self):
|
||||
self.frames = []
|
||||
|
||||
async def handle_rx(self, data):
|
||||
self.frames.append(bytes(data))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TCP send() wraps transport.write in try/except
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_send_write_error_fires_disconnect():
|
||||
"""TCP: OSError during transport.write fires _disconnect_callback."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
cb = AsyncMock()
|
||||
cx.set_disconnect_callback(cb)
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.write.side_effect = OSError("Broken pipe")
|
||||
cx.transport = mock_transport
|
||||
cx._send_count = 0
|
||||
cx._receive_count = 0
|
||||
|
||||
await cx.send(b"\x01\x02\x03")
|
||||
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert "tcp_write_failed" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serial send() fires disconnect on transport-lost and write error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_send_no_transport_fires_disconnect():
|
||||
"""Serial: send() on None transport fires _disconnect_callback ."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
cb = AsyncMock()
|
||||
cx.set_disconnect_callback(cb)
|
||||
cx.transport = None
|
||||
|
||||
await cx.send(b"\x01")
|
||||
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert reason == "serial_transport_lost"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_send_write_error_fires_disconnect():
|
||||
"""Serial: OSError during transport.write fires _disconnect_callback."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
cb = AsyncMock()
|
||||
cx.set_disconnect_callback(cb)
|
||||
|
||||
mock_transport = MagicMock()
|
||||
mock_transport.write.side_effect = OSError("Device not configured")
|
||||
cx.transport = mock_transport
|
||||
|
||||
await cx.send(b"\x01")
|
||||
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert "serial_write_failed" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BLE send() fires disconnect on transport-lost and write error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ble_send_no_client_fires_disconnect():
|
||||
"""BLE: send() with no client fires _disconnect_callback."""
|
||||
# Can't import BLEConnection directly if bleak isn't installed,
|
||||
# so test via dynamic import with a guard.
|
||||
try:
|
||||
from meshcore.ble_cx import BLEConnection
|
||||
except ImportError:
|
||||
pytest.skip("bleak not installed")
|
||||
|
||||
# BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it
|
||||
with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \
|
||||
patch("meshcore.ble_cx.BleakClient", MagicMock()):
|
||||
cx = BLEConnection.__new__(BLEConnection)
|
||||
cx.client = None
|
||||
cx._user_provided_client = None
|
||||
cx._user_provided_address = None
|
||||
cx._user_provided_device = None
|
||||
cx.address = None
|
||||
cx.device = None
|
||||
cx.pin = None
|
||||
cx.rx_char = None
|
||||
cb = AsyncMock()
|
||||
cx._disconnect_callback = cb
|
||||
|
||||
result = await cx.send(b"\x01")
|
||||
|
||||
assert result is False
|
||||
cb.assert_awaited_once()
|
||||
reason = cb.call_args[0][0]
|
||||
assert reason == "ble_transport_lost"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serial connect() times out if connection_made never fires
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_connect_timeout():
|
||||
"""Serial: connect() raises TimeoutError if connection_made never fires."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
|
||||
# Mock create_serial_connection to do nothing (never fires connection_made)
|
||||
async def mock_create(*args, **kwargs):
|
||||
return (MagicMock(), MagicMock())
|
||||
|
||||
with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection",
|
||||
side_effect=mock_create):
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await cx.connect(timeout=0.1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Oversize frame resets state and returns without dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_oversize_frame_empty_data_returns():
|
||||
"""TCP: oversize header with no trailing data returns without dispatch."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
|
||||
# Build a frame header with size > 300 and no payload data after header
|
||||
# Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4)
|
||||
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||
cx.handle_rx(header)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# No frames should be dispatched, and state should be reset
|
||||
assert reader.frames == []
|
||||
assert cx.header == b""
|
||||
assert cx.inframe == b""
|
||||
assert cx.frame_expected_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serial_oversize_frame_empty_data_returns():
|
||||
"""Serial: oversize header with no trailing data returns without dispatch."""
|
||||
cx = SerialConnection("/dev/null", 115200)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
|
||||
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||
cx.handle_rx(header)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert reader.frames == []
|
||||
assert cx.header == b""
|
||||
assert cx.inframe == b""
|
||||
assert cx.frame_expected_size == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TCP receive counter increments per MeshCore frame, not per TCP segment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_receive_count_per_frame_not_per_segment():
|
||||
"""TCP: _receive_count increments per completed frame, not per data_received call."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
cx._receive_count = 0
|
||||
|
||||
# Build a 4-byte payload frame
|
||||
payload = b"\xAA\xBB\xCC\xDD"
|
||||
frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload
|
||||
|
||||
# Split the frame into 3 TCP segments (simulating fragmentation)
|
||||
protocol = TCPConnection.MCClientProtocol(cx)
|
||||
protocol.data_received(frame[:2]) # partial header
|
||||
protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload
|
||||
protocol.data_received(frame[5:]) # remaining payload
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 3 data_received calls but only 1 completed frame
|
||||
assert cx._receive_count == 1
|
||||
assert reader.frames == [payload]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tcp_multiple_frames_count_correctly():
|
||||
"""TCP: two complete frames in separate segments → _receive_count == 2."""
|
||||
cx = TCPConnection("127.0.0.1", 5000)
|
||||
reader = RecordingReader()
|
||||
cx.set_reader(reader)
|
||||
cx._receive_count = 0
|
||||
|
||||
payload1 = b"\x01\x02"
|
||||
frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1
|
||||
payload2 = b"\x03\x04\x05"
|
||||
frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2
|
||||
|
||||
protocol = TCPConnection.MCClientProtocol(cx)
|
||||
protocol.data_received(frame1)
|
||||
protocol.data_received(frame2)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert cx._receive_count == 2
|
||||
assert reader.frames == [payload1, payload2]
|
||||
Reference in New Issue
Block a user