mirror of
https://github.com/meshcore-dev/meshcore_py.git
synced 2026-06-11 11:56:18 +00:00
Merge pull request #77 from mwolter805/fix/transport-symmetry
fix: symmetric disconnect signaling, serial timeout, BLE callback, oversize-frame recovery
This commit is contained in:
@@ -124,9 +124,12 @@ class BLEConnection:
|
|||||||
await self.client.pair()
|
await self.client.pair()
|
||||||
logger.info("BLE pairing successful")
|
logger.info("BLE pairing successful")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"BLE pairing failed: {e}")
|
logger.error(f"BLE pairing failed: {e}")
|
||||||
# Don't fail the connection if pairing fails, as the device
|
# A failed pairing leaves the transport in a half-usable
|
||||||
# might already be paired or not require pairing
|
# state — re-raise so the caller gets a clean failure
|
||||||
|
# instead of a silently degraded connection.
|
||||||
|
await self.client.disconnect()
|
||||||
|
raise
|
||||||
|
|
||||||
except BleakDeviceNotFoundError:
|
except BleakDeviceNotFoundError:
|
||||||
return None
|
return None
|
||||||
@@ -162,6 +165,17 @@ class BLEConnection:
|
|||||||
self.client = self._user_provided_client
|
self.client = self._user_provided_client
|
||||||
self.device = self._user_provided_device
|
self.device = self._user_provided_device
|
||||||
|
|
||||||
|
# Re-register disconnect callback on the reset client so subsequent
|
||||||
|
# disconnects after a reconnect cycle are still detected.
|
||||||
|
if self.client is not None and hasattr(self.client, 'set_disconnected_callback'):
|
||||||
|
try:
|
||||||
|
self.client.set_disconnected_callback(self.handle_disconnect)
|
||||||
|
except Exception:
|
||||||
|
# set_disconnected_callback may not be available on all bleak
|
||||||
|
# versions; the next connect() call will re-create the client
|
||||||
|
# with the callback anyway.
|
||||||
|
pass
|
||||||
|
|
||||||
if self._disconnect_callback:
|
if self._disconnect_callback:
|
||||||
self._spawn_background(self._disconnect_callback("ble_disconnect"))
|
self._spawn_background(self._disconnect_callback("ble_disconnect"))
|
||||||
|
|
||||||
@@ -179,11 +193,19 @@ class BLEConnection:
|
|||||||
async def send(self, data):
|
async def send(self, data):
|
||||||
if not self.client:
|
if not self.client:
|
||||||
logger.error("Client is not connected")
|
logger.error("Client is not connected")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback("ble_transport_lost")
|
||||||
return False
|
return False
|
||||||
if not self.rx_char:
|
if not self.rx_char:
|
||||||
logger.error("RX characteristic not found")
|
logger.error("RX characteristic not found")
|
||||||
return False
|
return False
|
||||||
|
try:
|
||||||
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
|
await self.client.write_gatt_char(self.rx_char, bytes(data), response=True)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"BLE write failed: {exc}")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback(f"ble_write_failed: {exc}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Disconnect from the BLE device."""
|
"""Disconnect from the BLE device."""
|
||||||
|
|||||||
@@ -60,9 +60,13 @@ class SerialConnection:
|
|||||||
def resume_writing(self):
|
def resume_writing(self):
|
||||||
logger.debug("resume writing")
|
logger.debug("resume writing")
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self, timeout: float = 10.0):
|
||||||
"""
|
"""
|
||||||
Connects to the device
|
Connects to the device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum seconds to wait for connection_made callback.
|
||||||
|
Defaults to 10.0. Raises asyncio.TimeoutError on expiry.
|
||||||
"""
|
"""
|
||||||
self._connected_event.clear()
|
self._connected_event.clear()
|
||||||
|
|
||||||
@@ -74,7 +78,7 @@ class SerialConnection:
|
|||||||
baudrate=self.baudrate,
|
baudrate=self.baudrate,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._connected_event.wait()
|
await asyncio.wait_for(self._connected_event.wait(), timeout=timeout)
|
||||||
logger.info("Serial Connection started")
|
logger.info("Serial Connection started")
|
||||||
return self.port
|
return self.port
|
||||||
|
|
||||||
@@ -110,7 +114,7 @@ class SerialConnection:
|
|||||||
self.frame_expected_size = 0
|
self.frame_expected_size = 0
|
||||||
if len(data) > 0: # rerun handle_rx on remaining data
|
if len(data) > 0: # rerun handle_rx on remaining data
|
||||||
self.handle_rx(data)
|
self.handle_rx(data)
|
||||||
return
|
return # nothing left to process after reset
|
||||||
|
|
||||||
upbound = self.frame_expected_size - len(self.inframe)
|
upbound = self.frame_expected_size - len(self.inframe)
|
||||||
if len(data) < upbound:
|
if len(data) < upbound:
|
||||||
@@ -133,11 +137,18 @@ class SerialConnection:
|
|||||||
async def send(self, data):
|
async def send(self, data):
|
||||||
if not self.transport:
|
if not self.transport:
|
||||||
logger.error("Transport not connected, cannot send data")
|
logger.error("Transport not connected, cannot send data")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback("serial_transport_lost")
|
||||||
return
|
return
|
||||||
size = len(data)
|
size = len(data)
|
||||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||||
logger.debug(f"sending pkt : {pkt}")
|
logger.debug(f"sending pkt : {pkt}")
|
||||||
|
try:
|
||||||
self.transport.write(pkt)
|
self.transport.write(pkt)
|
||||||
|
except OSError as exc:
|
||||||
|
logger.warning(f"Serial write failed: {exc}")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback(f"serial_write_failed: {exc}")
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Close the serial connection."""
|
"""Close the serial connection."""
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ class TCPConnection:
|
|||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data):
|
||||||
logger.debug("data received")
|
logger.debug("data received")
|
||||||
self.cx._receive_count += 1
|
|
||||||
self.cx.handle_rx(data)
|
self.cx.handle_rx(data)
|
||||||
|
|
||||||
def error_received(self, exc):
|
def error_received(self, exc):
|
||||||
@@ -101,7 +100,7 @@ class TCPConnection:
|
|||||||
self.frame_expected_size = 0
|
self.frame_expected_size = 0
|
||||||
if len(data) > 0: # rerun handle_rx on remaining data
|
if len(data) > 0: # rerun handle_rx on remaining data
|
||||||
self.handle_rx(data)
|
self.handle_rx(data)
|
||||||
return
|
return # nothing left to process after reset
|
||||||
|
|
||||||
upbound = self.frame_expected_size - len(self.inframe)
|
upbound = self.frame_expected_size - len(self.inframe)
|
||||||
if len(data) < upbound :
|
if len(data) < upbound :
|
||||||
@@ -111,6 +110,10 @@ class TCPConnection:
|
|||||||
|
|
||||||
self.inframe = self.inframe + data[0:upbound]
|
self.inframe = self.inframe + data[0:upbound]
|
||||||
data = data[upbound:]
|
data = data[upbound:]
|
||||||
|
# Increment per completed MeshCore frame, not per TCP segment (N04).
|
||||||
|
# The threshold heuristic in send() compares _send_count vs
|
||||||
|
# _receive_count — counting per-segment skews it under fragmentation.
|
||||||
|
self._receive_count += 1
|
||||||
if self.reader is not None:
|
if self.reader is not None:
|
||||||
# feed meshcore reader
|
# feed meshcore reader
|
||||||
self._spawn_background(self.reader.handle_rx(self.inframe))
|
self._spawn_background(self.reader.handle_rx(self.inframe))
|
||||||
@@ -142,7 +145,12 @@ class TCPConnection:
|
|||||||
size = len(data)
|
size = len(data)
|
||||||
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
pkt = b"\x3c" + size.to_bytes(2, byteorder="little") + data
|
||||||
logger.debug(f"sending pkt : {pkt}")
|
logger.debug(f"sending pkt : {pkt}")
|
||||||
|
try:
|
||||||
self.transport.write(pkt)
|
self.transport.write(pkt)
|
||||||
|
except (OSError, ConnectionResetError) as exc:
|
||||||
|
logger.warning(f"TCP write failed: {exc}")
|
||||||
|
if self._disconnect_callback:
|
||||||
|
await self._disconnect_callback(f"tcp_write_failed: {exc}")
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Close the TCP connection."""
|
"""Close the TCP connection."""
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TestBLEPinPairing(unittest.TestCase):
|
|||||||
|
|
||||||
@patch("meshcore.ble_cx.BleakClient")
|
@patch("meshcore.ble_cx.BleakClient")
|
||||||
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
|
def test_ble_connection_with_pin_failed_pairing(self, mock_bleak_client):
|
||||||
"""Test BLE connection with PIN when pairing fails but connection continues"""
|
"""Test BLE connection with PIN when pairing fails — re-raises (F17)."""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_client_instance = self._get_mock_bleak_client()
|
mock_client_instance = self._get_mock_bleak_client()
|
||||||
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
|
mock_client_instance.pair = AsyncMock(side_effect=Exception("Pairing failed"))
|
||||||
@@ -47,17 +47,16 @@ class TestBLEPinPairing(unittest.TestCase):
|
|||||||
pin = "123456"
|
pin = "123456"
|
||||||
ble_conn = BLEConnection(address=address, pin=pin)
|
ble_conn = BLEConnection(address=address, pin=pin)
|
||||||
|
|
||||||
# Act
|
# Act & Assert — pairing failure now re-raises instead of being
|
||||||
result = asyncio.run(ble_conn.connect())
|
# swallowed, because a half-usable transport is worse than a clean
|
||||||
|
# failure (forensics finding F17).
|
||||||
# Assert
|
with self.assertRaises(Exception) as ctx:
|
||||||
|
asyncio.run(ble_conn.connect())
|
||||||
|
self.assertIn("Pairing failed", str(ctx.exception))
|
||||||
mock_client_instance.connect.assert_called_once()
|
mock_client_instance.connect.assert_called_once()
|
||||||
mock_client_instance.pair.assert_called_once()
|
mock_client_instance.pair.assert_called_once()
|
||||||
mock_client_instance.start_notify.assert_called_once_with(
|
# disconnect should be called to clean up the failed connection
|
||||||
UART_TX_CHAR_UUID, ble_conn.handle_rx
|
mock_client_instance.disconnect.assert_called_once()
|
||||||
)
|
|
||||||
# Connection should still succeed even if pairing fails
|
|
||||||
self.assertEqual(result, address)
|
|
||||||
|
|
||||||
@patch("meshcore.ble_cx.BleakClient")
|
@patch("meshcore.ble_cx.BleakClient")
|
||||||
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):
|
def test_ble_connection_without_pin_no_pairing(self, mock_bleak_client):
|
||||||
|
|||||||
238
tests/unit/test_transport_symmetry.py
Normal file
238
tests/unit/test_transport_symmetry.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""
|
||||||
|
Verification tests for transport symmetry fixes.
|
||||||
|
|
||||||
|
Covers: send symmetry across transports, serial disconnect callback on
|
||||||
|
transport-lost, serial connect timeout, oversize-frame return, BLE
|
||||||
|
disconnect-callback re-registration, BLE pairing failure re-raise,
|
||||||
|
TCP counter per frame not per segment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from meshcore.tcp_cx import TCPConnection
|
||||||
|
from meshcore.serial_cx import SerialConnection
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RecordingReader:
|
||||||
|
"""Minimal reader mock that records dispatched frames."""
|
||||||
|
def __init__(self):
|
||||||
|
self.frames = []
|
||||||
|
|
||||||
|
async def handle_rx(self, data):
|
||||||
|
self.frames.append(bytes(data))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TCP send() wraps transport.write in try/except
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_send_write_error_fires_disconnect():
|
||||||
|
"""TCP: OSError during transport.write fires _disconnect_callback."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(cb)
|
||||||
|
|
||||||
|
mock_transport = MagicMock()
|
||||||
|
mock_transport.write.side_effect = OSError("Broken pipe")
|
||||||
|
cx.transport = mock_transport
|
||||||
|
cx._send_count = 0
|
||||||
|
cx._receive_count = 0
|
||||||
|
|
||||||
|
await cx.send(b"\x01\x02\x03")
|
||||||
|
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert "tcp_write_failed" in reason
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serial send() fires disconnect on transport-lost and write error
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_send_no_transport_fires_disconnect():
|
||||||
|
"""Serial: send() on None transport fires _disconnect_callback ."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(cb)
|
||||||
|
cx.transport = None
|
||||||
|
|
||||||
|
await cx.send(b"\x01")
|
||||||
|
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert reason == "serial_transport_lost"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_send_write_error_fires_disconnect():
|
||||||
|
"""Serial: OSError during transport.write fires _disconnect_callback."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx.set_disconnect_callback(cb)
|
||||||
|
|
||||||
|
mock_transport = MagicMock()
|
||||||
|
mock_transport.write.side_effect = OSError("Device not configured")
|
||||||
|
cx.transport = mock_transport
|
||||||
|
|
||||||
|
await cx.send(b"\x01")
|
||||||
|
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert "serial_write_failed" in reason
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BLE send() fires disconnect on transport-lost and write error
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ble_send_no_client_fires_disconnect():
|
||||||
|
"""BLE: send() with no client fires _disconnect_callback."""
|
||||||
|
# Can't import BLEConnection directly if bleak isn't installed,
|
||||||
|
# so test via dynamic import with a guard.
|
||||||
|
try:
|
||||||
|
from meshcore.ble_cx import BLEConnection
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("bleak not installed")
|
||||||
|
|
||||||
|
# BLEConnection.__init__ checks BLEAK_AVAILABLE; patch it
|
||||||
|
with patch("meshcore.ble_cx.BLEAK_AVAILABLE", True), \
|
||||||
|
patch("meshcore.ble_cx.BleakClient", MagicMock()):
|
||||||
|
cx = BLEConnection.__new__(BLEConnection)
|
||||||
|
cx.client = None
|
||||||
|
cx._user_provided_client = None
|
||||||
|
cx._user_provided_address = None
|
||||||
|
cx._user_provided_device = None
|
||||||
|
cx.address = None
|
||||||
|
cx.device = None
|
||||||
|
cx.pin = None
|
||||||
|
cx.rx_char = None
|
||||||
|
cb = AsyncMock()
|
||||||
|
cx._disconnect_callback = cb
|
||||||
|
|
||||||
|
result = await cx.send(b"\x01")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
cb.assert_awaited_once()
|
||||||
|
reason = cb.call_args[0][0]
|
||||||
|
assert reason == "ble_transport_lost"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serial connect() times out if connection_made never fires
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_connect_timeout():
|
||||||
|
"""Serial: connect() raises TimeoutError if connection_made never fires."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
|
||||||
|
# Mock create_serial_connection to do nothing (never fires connection_made)
|
||||||
|
async def mock_create(*args, **kwargs):
|
||||||
|
return (MagicMock(), MagicMock())
|
||||||
|
|
||||||
|
with patch("meshcore.serial_cx.serial_asyncio.create_serial_connection",
|
||||||
|
side_effect=mock_create):
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await cx.connect(timeout=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Oversize frame resets state and returns without dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_oversize_frame_empty_data_returns():
|
||||||
|
"""TCP: oversize header with no trailing data returns without dispatch."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
|
||||||
|
# Build a frame header with size > 300 and no payload data after header
|
||||||
|
# Header: 0x3e + 2-byte LE size (e.g. 500 = 0x01F4)
|
||||||
|
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||||
|
cx.handle_rx(header)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# No frames should be dispatched, and state should be reset
|
||||||
|
assert reader.frames == []
|
||||||
|
assert cx.header == b""
|
||||||
|
assert cx.inframe == b""
|
||||||
|
assert cx.frame_expected_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_serial_oversize_frame_empty_data_returns():
|
||||||
|
"""Serial: oversize header with no trailing data returns without dispatch."""
|
||||||
|
cx = SerialConnection("/dev/null", 115200)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
|
||||||
|
header = b"\x3e" + (500).to_bytes(2, "little")
|
||||||
|
cx.handle_rx(header)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert reader.frames == []
|
||||||
|
assert cx.header == b""
|
||||||
|
assert cx.inframe == b""
|
||||||
|
assert cx.frame_expected_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TCP receive counter increments per MeshCore frame, not per TCP segment
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_receive_count_per_frame_not_per_segment():
|
||||||
|
"""TCP: _receive_count increments per completed frame, not per data_received call."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
cx._receive_count = 0
|
||||||
|
|
||||||
|
# Build a 4-byte payload frame
|
||||||
|
payload = b"\xAA\xBB\xCC\xDD"
|
||||||
|
frame = b"\x3e" + len(payload).to_bytes(2, "little") + payload
|
||||||
|
|
||||||
|
# Split the frame into 3 TCP segments (simulating fragmentation)
|
||||||
|
protocol = TCPConnection.MCClientProtocol(cx)
|
||||||
|
protocol.data_received(frame[:2]) # partial header
|
||||||
|
protocol.data_received(frame[2:5]) # rest of header + 2 bytes payload
|
||||||
|
protocol.data_received(frame[5:]) # remaining payload
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# 3 data_received calls but only 1 completed frame
|
||||||
|
assert cx._receive_count == 1
|
||||||
|
assert reader.frames == [payload]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tcp_multiple_frames_count_correctly():
|
||||||
|
"""TCP: two complete frames in separate segments → _receive_count == 2."""
|
||||||
|
cx = TCPConnection("127.0.0.1", 5000)
|
||||||
|
reader = RecordingReader()
|
||||||
|
cx.set_reader(reader)
|
||||||
|
cx._receive_count = 0
|
||||||
|
|
||||||
|
payload1 = b"\x01\x02"
|
||||||
|
frame1 = b"\x3e" + len(payload1).to_bytes(2, "little") + payload1
|
||||||
|
payload2 = b"\x03\x04\x05"
|
||||||
|
frame2 = b"\x3e" + len(payload2).to_bytes(2, "little") + payload2
|
||||||
|
|
||||||
|
protocol = TCPConnection.MCClientProtocol(cx)
|
||||||
|
protocol.data_received(frame1)
|
||||||
|
protocol.data_received(frame2)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert cx._receive_count == 2
|
||||||
|
assert reader.frames == [payload1, payload2]
|
||||||
Reference in New Issue
Block a user