Merge pull request #70 from meshcore-dev/feature/mesh-request-lock

Add mesh request lock to serialize firmware-bound commands
This commit is contained in:
fdlamotte
2026-04-09 05:15:31 -04:00
committed by GitHub
7 changed files with 324 additions and 245 deletions

View File

@@ -64,6 +64,7 @@ class CommandHandlerBase:
self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None self._sender_func: Optional[Callable[[bytes], Coroutine[Any, Any, None]]] = None
self._reader: Optional[MessageReader] = None self._reader: Optional[MessageReader] = None
self.dispatcher: Optional[EventDispatcher] = None self.dispatcher: Optional[EventDispatcher] = None
self._mesh_request_lock = asyncio.Lock()
self.default_timeout = ( self.default_timeout = (
default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT default_timeout if default_timeout is not None else self.DEFAULT_TIMEOUT
) )

View File

@@ -19,6 +19,7 @@ class BinaryCommandHandler(CommandHandlerBase):
return await self.req_status_sync(contact, timeout, min_timeout) return await self.req_status_sync(contact, timeout, min_timeout)
async def req_status_sync(self, contact, timeout=0, min_timeout=0): async def req_status_sync(self, contact, timeout=0, min_timeout=0):
async with self._mesh_request_lock:
res = await self.send_binary_req( res = await self.send_binary_req(
contact, contact,
BinaryReqType.STATUS, BinaryReqType.STATUS,
@@ -48,6 +49,7 @@ class BinaryCommandHandler(CommandHandlerBase):
return await self.req_telemetry_sync(contact, timeout, min_timeout) return await self.req_telemetry_sync(contact, timeout, min_timeout)
async def req_telemetry_sync(self, contact, timeout=0, min_timeout=0): async def req_telemetry_sync(self, contact, timeout=0, min_timeout=0):
async with self._mesh_request_lock:
res = await self.send_binary_req( res = await self.send_binary_req(
contact, contact,
BinaryReqType.TELEMETRY, BinaryReqType.TELEMETRY,
@@ -63,7 +65,6 @@ class BinaryCommandHandler(CommandHandlerBase):
if self.dispatcher is None: if self.dispatcher is None:
return None return None
# Listen for TELEMETRY_RESPONSE event
telem_event = await self.dispatcher.wait_for_event( telem_event = await self.dispatcher.wait_for_event(
EventType.TELEMETRY_RESPONSE, EventType.TELEMETRY_RESPONSE,
attribute_filters={"tag": res.payload["expected_ack"].hex()}, attribute_filters={"tag": res.payload["expected_ack"].hex()},
@@ -77,6 +78,7 @@ class BinaryCommandHandler(CommandHandlerBase):
return await self.req_mma_sync(contact, start, end, timeout,min_timeout) return await self.req_mma_sync(contact, start, end, timeout,min_timeout)
async def req_mma_sync(self, contact, start, end, timeout=0,min_timeout=0): async def req_mma_sync(self, contact, start, end, timeout=0,min_timeout=0):
async with self._mesh_request_lock:
req = ( req = (
start.to_bytes(4, "little", signed=False) start.to_bytes(4, "little", signed=False)
+ end.to_bytes(4, "little", signed=False) + end.to_bytes(4, "little", signed=False)
@@ -97,7 +99,6 @@ class BinaryCommandHandler(CommandHandlerBase):
if self.dispatcher is None: if self.dispatcher is None:
return None return None
# Listen for MMA_RESPONSE
mma_event = await self.dispatcher.wait_for_event( mma_event = await self.dispatcher.wait_for_event(
EventType.MMA_RESPONSE, EventType.MMA_RESPONSE,
attribute_filters={"tag": res.payload["expected_ack"].hex()}, attribute_filters={"tag": res.payload["expected_ack"].hex()},
@@ -111,6 +112,7 @@ class BinaryCommandHandler(CommandHandlerBase):
return await self.req_acl_sync(contact, timeout, min_timeout) return await self.req_acl_sync(contact, timeout, min_timeout)
async def req_acl_sync(self, contact, timeout=0, min_timeout=0): async def req_acl_sync(self, contact, timeout=0, min_timeout=0):
async with self._mesh_request_lock:
req = b"\0\0" req = b"\0\0"
res = await self.send_binary_req( res = await self.send_binary_req(
contact, contact,
@@ -127,7 +129,6 @@ class BinaryCommandHandler(CommandHandlerBase):
if self.dispatcher is None: if self.dispatcher is None:
return None return None
# Listen for ACL_RESPONSE event with matching tag
acl_event = await self.dispatcher.wait_for_event( acl_event = await self.dispatcher.wait_for_event(
EventType.ACL_RESPONSE, EventType.ACL_RESPONSE,
attribute_filters={"tag": res.payload["expected_ack"].hex()}, attribute_filters={"tag": res.payload["expected_ack"].hex()},
@@ -172,7 +173,7 @@ class BinaryCommandHandler(CommandHandlerBase):
timeout=0, timeout=0,
min_timeout=0 min_timeout=0
): ):
async with self._mesh_request_lock:
res = await self.req_neighbours_async(contact, res = await self.req_neighbours_async(contact,
count=count, count=count,
offset=offset, offset=offset,
@@ -190,7 +191,6 @@ class BinaryCommandHandler(CommandHandlerBase):
if self.dispatcher is None: if self.dispatcher is None:
return None return None
# Listen for NEIGHBOUR_RESPONSE
neighbours_event = await self.dispatcher.wait_for_event( neighbours_event = await self.dispatcher.wait_for_event(
EventType.NEIGHBOURS_RESPONSE, EventType.NEIGHBOURS_RESPONSE,
attribute_filters={"tag": res.payload["expected_ack"].hex()}, attribute_filters={"tag": res.payload["expected_ack"].hex()},
@@ -259,6 +259,7 @@ class BinaryCommandHandler(CommandHandlerBase):
) )
async def req_regions_sync(self, contact, timeout=0, min_timeout=0): async def req_regions_sync(self, contact, timeout=0, min_timeout=0):
async with self._mesh_request_lock:
res = await self.req_regions_async(contact, timeout, min_timeout) res = await self.req_regions_async(contact, timeout, min_timeout)
if res.type == EventType.ERROR: if res.type == EventType.ERROR:
@@ -294,7 +295,7 @@ class BinaryCommandHandler(CommandHandlerBase):
) )
async def req_owner_sync(self, contact, timeout=0, min_timeout=0): async def req_owner_sync(self, contact, timeout=0, min_timeout=0):
async with self._mesh_request_lock:
res = await self.req_owner_async(contact, timeout, min_timeout) res = await self.req_owner_async(contact, timeout, min_timeout)
if res.type == EventType.ERROR: if res.type == EventType.ERROR:
@@ -332,7 +333,7 @@ class BinaryCommandHandler(CommandHandlerBase):
) )
async def req_basic_sync(self, contact, timeout=0, min_timeout=0): async def req_basic_sync(self, contact, timeout=0, min_timeout=0):
async with self._mesh_request_lock:
res = await self.req_basic_async(contact, timeout, min_timeout) res = await self.req_basic_async(contact, timeout, min_timeout)
if res.type == EventType.ERROR: if res.type == EventType.ERROR:

View File

@@ -24,18 +24,37 @@ class MessagingCommands(CommandHandlerBase):
timeout, timeout,
) )
async def send_login(self, dst: DestinationType, pwd: str) -> Event: async def _send_login_raw(self, dst: DestinationType, pwd: str) -> Event:
dst_bytes = _validate_destination(dst, prefix_length=32) dst_bytes = _validate_destination(dst, prefix_length=32)
logger.debug(f"Sending login request to: {dst_bytes.hex()}") logger.debug(f"Sending login request to: {dst_bytes.hex()}")
data = b"\x1a" + dst_bytes + pwd.encode("utf-8") data = b"\x1a" + dst_bytes + pwd.encode("utf-8")
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_login(self, dst: DestinationType, pwd: str) -> Event:
logger.warning("*** please consider using send_login_sync instead of send_login")
return await self._send_login_raw(dst, pwd)
async def send_login_sync(self, dst: DestinationType, pwd: str, timeout=0, min_timeout=0) -> Optional[Event]:
"""Send login to a remote node and wait for the response."""
async with self._mesh_request_lock:
result = await self._send_login_raw(dst, pwd)
if result is None or result.type == EventType.ERROR:
return None
timeout = result.payload["suggested_timeout"] / 800 if timeout == 0 else timeout
timeout = timeout if timeout > min_timeout else min_timeout
login_event = await self.dispatcher.wait_for_event(
EventType.LOGIN_SUCCESS,
timeout=timeout,
)
return login_event
async def send_logout(self, dst: DestinationType) -> Event: async def send_logout(self, dst: DestinationType) -> Event:
dst_bytes = _validate_destination(dst, prefix_length=32) dst_bytes = _validate_destination(dst, prefix_length=32)
data = b"\x1d" + dst_bytes data = b"\x1d" + dst_bytes
return await self.send(data, [EventType.OK, EventType.ERROR]) return await self.send(data, [EventType.OK, EventType.ERROR])
async def send_statusreq(self, dst: DestinationType) -> Event: async def send_statusreq(self, dst: DestinationType) -> Event:
logger.warning("*** please consider using req_status_sync instead of send_statusreq")
dst_bytes = _validate_destination(dst, prefix_length=32) dst_bytes = _validate_destination(dst, prefix_length=32)
logger.debug(f"Sending status request to: {dst_bytes.hex()}") logger.debug(f"Sending status request to: {dst_bytes.hex()}")
data = b"\x1b" + dst_bytes data = b"\x1b" + dst_bytes
@@ -106,7 +125,7 @@ class MessagingCommands(CommandHandlerBase):
# if we have a full key (meaning we can reset path) consider direct # if we have a full key (meaning we can reset path) consider direct
# else consider flood # else consider flood
flood = len(dst_bytes) < 32 flood = len(dst_bytes) < 32
logger.info(f"send_msg_with_retry: can't determine if flood, assume {flood}") logger.info(f"send_msg_with_retry: can't determine if flood, assume {flood}")
res = None res = None
while attempts < max_attempts and res is None \ while attempts < max_attempts and res is None \
and (not flood or flood_attempts < max_flood_attempts): and (not flood or flood_attempts < max_flood_attempts):
@@ -122,7 +141,7 @@ class MessagingCommands(CommandHandlerBase):
contact["out_path_len"] = -1 contact["out_path_len"] = -1
if attempts > 0: if attempts > 0:
logger.info(f"Retry sending msg: {attempts + 1}") logger.info(f"Retry sending msg: {attempts + 1}")
result = await self.send_msg(dst, msg, timestamp, attempt=attempts) result = await self.send_msg(dst, msg, timestamp, attempt=attempts)
if result.type == EventType.ERROR: if result.type == EventType.ERROR:
@@ -166,17 +185,36 @@ class MessagingCommands(CommandHandlerBase):
return await self.send(data, [EventType.OK, EventType.ERROR]) return await self.send(data, [EventType.OK, EventType.ERROR])
async def send_telemetry_req(self, dst: DestinationType) -> Event: async def send_telemetry_req(self, dst: DestinationType) -> Event:
logger.warning("*** please consider using req_telemetry_sync instead of send_telemetry_req")
dst_bytes = _validate_destination(dst, prefix_length=32) dst_bytes = _validate_destination(dst, prefix_length=32)
logger.debug(f"Asking telemetry to {dst_bytes.hex()}") logger.debug(f"Asking telemetry to {dst_bytes.hex()}")
data = b"\x27\x00\x00\x00" + dst_bytes data = b"\x27\x00\x00\x00" + dst_bytes
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_path_discovery(self, dst: DestinationType) -> Event: async def _send_path_discovery_raw(self, dst: DestinationType) -> Event:
dst_bytes = _validate_destination(dst, prefix_length=32) dst_bytes = _validate_destination(dst, prefix_length=32)
logger.debug(f"Path discovery request for {dst_bytes.hex()}") logger.debug(f"Path discovery request for {dst_bytes.hex()}")
data = b"\x34\x00" + dst_bytes data = b"\x34\x00" + dst_bytes
return await self.send(data, [EventType.MSG_SENT, EventType.ERROR]) return await self.send(data, [EventType.MSG_SENT, EventType.ERROR])
async def send_path_discovery(self, dst: DestinationType) -> Event:
logger.warning("*** please consider using send_path_discovery_sync instead of send_path_discovery")
return await self._send_path_discovery_raw(dst)
async def send_path_discovery_sync(self, dst: DestinationType, timeout=0, min_timeout=0) -> Optional[Event]:
"""Send path discovery request and wait for the response."""
async with self._mesh_request_lock:
result = await self._send_path_discovery_raw(dst)
if result is None or result.type == EventType.ERROR:
return None
timeout = result.payload["suggested_timeout"] / 800 if timeout == 0 else timeout
timeout = timeout if timeout > min_timeout else min_timeout
path_event = await self.dispatcher.wait_for_event(
EventType.PATH_RESPONSE,
timeout=timeout,
)
return path_event
async def send_trace( async def send_trace(
self, self,
auth_code: int = 0, auth_code: int = 0,

View File

@@ -2,10 +2,12 @@ import pytest
import asyncio import asyncio
from unittest.mock import MagicMock, AsyncMock from unittest.mock import MagicMock, AsyncMock
from meshcore.commands import CommandHandler from meshcore.commands import CommandHandler
from meshcore.events import EventType, Event from meshcore.events import EventType, Event, Subscription
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
VALID_PUBKEY_HEX = "0123456789abcdef" * 4 # 64 hex chars = 32 bytes
# Fixtures # Fixtures
@pytest.fixture @pytest.fixture
@@ -20,6 +22,15 @@ def mock_dispatcher():
dispatcher = MagicMock() dispatcher = MagicMock()
dispatcher.wait_for_event = AsyncMock() dispatcher.wait_for_event = AsyncMock()
dispatcher.dispatch = AsyncMock() dispatcher.dispatch = AsyncMock()
def fake_subscribe(event_type, handler, attribute_filters=None):
sub = MagicMock(spec=Subscription)
sub.unsubscribe = MagicMock()
dispatcher._last_subscribe_handler = handler
dispatcher._last_subscribe_event_type = event_type
return sub
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
return dispatcher return dispatcher
@@ -36,20 +47,17 @@ def command_handler(mock_connection, mock_dispatcher):
return handler return handler
# Test helper
def setup_event_response(mock_dispatcher, event_type, payload, attribute_filters=None): def setup_event_response(mock_dispatcher, event_type, payload, attribute_filters=None):
async def wait_response(requested_type, filters=None, timeout=None): def fake_subscribe(evt_type, handler, attr_filters=None):
if requested_type == event_type: sub = MagicMock(spec=Subscription)
if filters and attribute_filters: sub.unsubscribe = MagicMock()
if not all( if evt_type == event_type:
attribute_filters.get(key) == value asyncio.get_event_loop().call_soon(
for key, value in filters.items() handler, Event(event_type, payload)
): )
return None return sub
return Event(event_type, payload)
return None
mock_dispatcher.wait_for_event.side_effect = wait_response mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
# Basic tests # Basic tests
@@ -72,11 +80,9 @@ async def test_send_with_event(command_handler, mock_connection, mock_dispatcher
async def test_send_timeout(command_handler, mock_connection, mock_dispatcher): async def test_send_timeout(command_handler, mock_connection, mock_dispatcher):
mock_dispatcher.wait_for_event.side_effect = asyncio.TimeoutError
result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1) result = await command_handler.send(b"test_command", [EventType.OK], timeout=0.1)
assert result.type == EventType.ERROR assert result.type == EventType.ERROR
assert result.payload == {"reason": "timeout"} assert result.payload == {"reason": "no_event_received"}
# Destination validation tests # Destination validation tests
@@ -106,7 +112,7 @@ async def test_validate_destination_contact_object(command_handler, mock_connect
# Command tests # Command tests
async def test_send_login(command_handler, mock_connection): async def test_send_login(command_handler, mock_connection):
await command_handler.send_login("0123456789abcdef", "password") await command_handler.send_login(VALID_PUBKEY_HEX, "password")
assert mock_connection.send.call_args[0][0].startswith(b"\x1a") assert mock_connection.send.call_args[0][0].startswith(b"\x1a")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
@@ -198,15 +204,14 @@ async def test_get_contacts(command_handler, mock_connection):
async def test_reset_path(command_handler, mock_connection): async def test_reset_path(command_handler, mock_connection):
dst = "0123456789abcdef" command_handler._get_contact_by_prefix = lambda prefix: None
await command_handler.reset_path(dst) await command_handler.reset_path(VALID_PUBKEY_HEX)
assert mock_connection.send.call_args[0][0].startswith(b"\x0d") assert mock_connection.send.call_args[0][0].startswith(b"\x0d")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
async def test_share_contact(command_handler, mock_connection): async def test_share_contact(command_handler, mock_connection):
dst = "0123456789abcdef" await command_handler.share_contact(VALID_PUBKEY_HEX)
await command_handler.share_contact(dst)
assert mock_connection.send.call_args[0][0].startswith(b"\x10") assert mock_connection.send.call_args[0][0].startswith(b"\x10")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
@@ -218,15 +223,13 @@ async def test_export_contact(command_handler, mock_connection):
# Test exporting specific contact # Test exporting specific contact
mock_connection.reset_mock() mock_connection.reset_mock()
dst = "0123456789abcdef" await command_handler.export_contact(VALID_PUBKEY_HEX)
await command_handler.export_contact(dst)
assert mock_connection.send.call_args[0][0].startswith(b"\x11") assert mock_connection.send.call_args[0][0].startswith(b"\x11")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
async def test_remove_contact(command_handler, mock_connection): async def test_remove_contact(command_handler, mock_connection):
dst = "0123456789abcdef" await command_handler.remove_contact(VALID_PUBKEY_HEX)
await command_handler.remove_contact(dst)
assert mock_connection.send.call_args[0][0].startswith(b"\x0f") assert mock_connection.send.call_args[0][0].startswith(b"\x0f")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
@@ -242,15 +245,13 @@ async def test_get_msg(command_handler, mock_connection):
async def test_send_logout(command_handler, mock_connection): async def test_send_logout(command_handler, mock_connection):
dst = "0123456789abcdef" await command_handler.send_logout(VALID_PUBKEY_HEX)
await command_handler.send_logout(dst)
assert mock_connection.send.call_args[0][0].startswith(b"\x1d") assert mock_connection.send.call_args[0][0].startswith(b"\x1d")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
async def test_send_statusreq(command_handler, mock_connection): async def test_send_statusreq(command_handler, mock_connection):
dst = "0123456789abcdef" await command_handler.send_statusreq(VALID_PUBKEY_HEX)
await command_handler.send_statusreq(dst)
assert mock_connection.send.call_args[0][0].startswith(b"\x1b") assert mock_connection.send.call_args[0][0].startswith(b"\x1b")
assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0] assert b"\x01\x23\x45\x67\x89\xab" in mock_connection.send.call_args[0][0]
@@ -261,10 +262,10 @@ async def test_send_trace(command_handler, mock_connection):
first_call = mock_connection.send.call_args[0][0] first_call = mock_connection.send.call_args[0][0]
assert first_call.startswith(b"\x24") # 36 in decimal = 0x24 in hex assert first_call.startswith(b"\x24") # 36 in decimal = 0x24 in hex
# Test with all parameters # Test with all parameters (flags=1 means path_hash_len=2, so 4 hex chars each)
mock_connection.reset_mock() mock_connection.reset_mock()
await command_handler.send_trace( await command_handler.send_trace(
auth_code=12345, tag=67890, flags=1, path="01,23,45" auth_code=12345, tag=67890, flags=1, path="0123,2345,4567"
) )
second_call = mock_connection.send.call_args[0][0] second_call = mock_connection.send.call_args[0][0]
assert second_call.startswith(b"\x24") assert second_call.startswith(b"\x24")
@@ -273,25 +274,14 @@ async def test_send_trace(command_handler, mock_connection):
async def test_send_with_multiple_expected_events_returns_first_completed( async def test_send_with_multiple_expected_events_returns_first_completed(
command_handler, mock_connection, mock_dispatcher command_handler, mock_connection, mock_dispatcher
): ):
# Setup the dispatcher to return an ERROR event
error_payload = {"reason": "command_failed"} error_payload = {"reason": "command_failed"}
setup_event_response(mock_dispatcher, EventType.ERROR, error_payload)
async def simulate_error_event(*args, **kwargs):
# Simulate an ERROR event being returned
return Event(EventType.ERROR, error_payload)
# Patch the wait_for_event method to return our simulated event
mock_dispatcher.wait_for_event.side_effect = simulate_error_event
# Call send with both OK and ERROR in the expected_events list, with OK first
result = await command_handler.send( result = await command_handler.send(
b"test_command", [EventType.OK, EventType.ERROR] b"test_command", [EventType.OK, EventType.ERROR]
) )
# Verify the command was sent
mock_connection.send.assert_called_once_with(b"test_command") mock_connection.send.assert_called_once_with(b"test_command")
# Verify that even though OK was listed first, the ERROR event was returned
assert result.type == EventType.ERROR assert result.type == EventType.ERROR
assert result.payload == error_payload assert result.payload == error_payload

View File

@@ -0,0 +1,39 @@
"""Tests for LPP parsing to verify current values are handled correctly."""
import json
import pytest
from cayennelpp import LppFrame, LppData
class TestLppCurrentParsing:
"""Tests to verify LPP current values pass through correctly."""
def test_large_current_value_wraps_signed(self):
"""
When raw bytes represent a large unsigned value (like 65525),
values above 32.767 are reinterpreted as signed (negative).
65.525 - 65.536 = -0.011
"""
from meshcore.lpp_json_encoder import lpp_json_encoder
# Channel 2, Type 117 (current), Value 65525 raw = 0xFF 0xF5 (big-endian)
raw_bytes = bytes([2, 117, 0xFF, 0xF5])
lppdata = LppData.from_bytes(raw_bytes)
lpp = json.loads(json.dumps(LppFrame([lppdata]), default=lpp_json_encoder))
assert len(lpp) == 1
assert lpp[0]['channel'] == 2
assert lpp[0]['type'] == 'current'
assert lpp[0]['value'] == -0.011
def test_normal_positive_current(self):
"""Normal positive current should work correctly."""
from meshcore.lpp_json_encoder import lpp_json_encoder
# Channel 2, Type 117 (current), Value 500 raw = 0x01 0xF4 (big-endian)
raw_bytes = bytes([2, 117, 0x01, 0xF4])
lppdata = LppData.from_bytes(raw_bytes)
lpp = json.loads(json.dumps(LppFrame([lppdata]), default=lpp_json_encoder))
assert lpp[0]['value'] == 0.5

View File

@@ -7,13 +7,12 @@ import pytest
import asyncio import asyncio
from unittest.mock import MagicMock, AsyncMock from unittest.mock import MagicMock, AsyncMock
from meshcore.commands import CommandHandler from meshcore.commands import CommandHandler
from meshcore.events import Event, EventType from meshcore.events import Event, EventType, Subscription
from meshcore.reader import MessageReader from meshcore.reader import MessageReader
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
# Fixtures (consistent with existing test patterns)
@pytest.fixture @pytest.fixture
def mock_connection(): def mock_connection():
connection = MagicMock() connection = MagicMock()
@@ -26,6 +25,13 @@ def mock_dispatcher():
dispatcher = MagicMock() dispatcher = MagicMock()
dispatcher.wait_for_event = AsyncMock() dispatcher.wait_for_event = AsyncMock()
dispatcher.dispatch = AsyncMock() dispatcher.dispatch = AsyncMock()
def fake_subscribe(event_type, handler, attribute_filters=None):
sub = MagicMock(spec=Subscription)
sub.unsubscribe = MagicMock()
return sub
dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
return dispatcher return dispatcher
@@ -41,14 +47,17 @@ def command_handler(mock_connection, mock_dispatcher):
return handler return handler
# Test helper (consistent with existing patterns)
def setup_event_response(mock_dispatcher, event_type, payload): def setup_event_response(mock_dispatcher, event_type, payload):
async def wait_response(requested_type, filters=None, timeout=None): def fake_subscribe(evt_type, handler, attr_filters=None):
if requested_type == event_type: sub = MagicMock(spec=Subscription)
return Event(event_type, payload) sub.unsubscribe = MagicMock()
return None if evt_type == event_type:
asyncio.get_event_loop().call_soon(
handler, Event(event_type, payload)
)
return sub
mock_dispatcher.wait_for_event.side_effect = wait_response mock_dispatcher.subscribe = MagicMock(side_effect=fake_subscribe)
# Command tests # Command tests

View File

@@ -28,8 +28,9 @@ async def test_binary_response():
# Register the binary request first # Register the binary request first
tag = "417db968" tag = "417db968"
from meshcore.parsing import BinaryReqType from meshcore.packets import BinaryReqType
reader.register_binary_request(tag, BinaryReqType.ACL, 10.0) pubkey_prefix = "993acd42fc77"
reader.register_binary_request(pubkey_prefix, tag, BinaryReqType.ACL, 10.0)
print(f"Registered ACL request with tag {tag}") print(f"Registered ACL request with tag {tag}")
await reader.handle_rx(packet_data) await reader.handle_rx(packet_data)
@@ -64,7 +65,7 @@ async def test_binary_response():
print(f"Request type in response: 0x{request_type:02x} ({request_type})") print(f"Request type in response: 0x{request_type:02x} ({request_type})")
# Map request types to expected events # Map request types to expected events
from meshcore.parsing import BinaryReqType from meshcore.packets import BinaryReqType
if request_type == BinaryReqType.STATUS.value: if request_type == BinaryReqType.STATUS.value:
expected_event = EventType.STATUS_RESPONSE expected_event = EventType.STATUS_RESPONSE
elif request_type == BinaryReqType.TELEMETRY.value: elif request_type == BinaryReqType.TELEMETRY.value: