- Updated protocol handlers to more reliably remove active waiters when task cancellation occurs
- Fixed checks where expecting a KeyError when it should be checking if not None
- Updated next_packet_id property to correctly check if there are any packet_ids available. Avoids infinite loop if all packet ids are used.
main
rparry-jones 2022-12-05 10:08:04 +11:00 zatwierdzone przez Florian Ludwig
rodzic 1a2812c5fc
commit d0eb64dc19
4 zmienionych plików z 121 dodań i 35 usunięć

Wyświetl plik

@ -106,17 +106,18 @@ class ClientProtocolHandler(ProtocolHandler):
# Wait for SUBACK is received
waiter = futures.Future()
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
return_codes = await waiter
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
try:
return_codes = await waiter
finally:
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
return return_codes
async def handle_suback(self, suback: SubackPacket):
packet_id = suback.variable_header.packet_id
try:
waiter = self._subscriptions_waiter.get(packet_id)
waiter = self._subscriptions_waiter.get(packet_id)
if waiter is not None:
waiter.set_result(suback.payload.return_codes)
except KeyError:
else:
self.logger.warning(
"Received SUBACK for unknown pending subscription with Id: %s"
% packet_id
@ -132,15 +133,17 @@ class ClientProtocolHandler(ProtocolHandler):
await self._send_packet(unsubscribe)
waiter = futures.Future()
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
await waiter
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
try:
await waiter
finally:
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
async def handle_unsuback(self, unsuback: UnsubackPacket):
packet_id = unsuback.variable_header.packet_id
try:
waiter = self._unsubscriptions_waiter.get(packet_id)
waiter = self._unsubscriptions_waiter.get(packet_id)
if waiter is not None:
waiter.set_result(None)
except KeyError:
else:
self.logger.warning(
"Received UNSUBACK for unknown pending subscription with Id: %s"
% packet_id
@ -152,10 +155,12 @@ class ClientProtocolHandler(ProtocolHandler):
async def mqtt_ping(self):
ping_packet = PingReqPacket()
await self._send_packet(ping_packet)
resp = await self._pingresp_queue.get()
if self._ping_task:
self._ping_task = None
try:
await self._send_packet(ping_packet)
resp = await self._pingresp_queue.get()
finally:
if self._ping_task:
self._ping_task = None
return resp
async def handle_pingresp(self, pingresp: PingRespPacket):

Wyświetl plik

@ -293,12 +293,13 @@ class ProtocolHandler:
# Wait for puback
waiter = asyncio.Future()
self._puback_waiters[app_message.packet_id] = waiter
await waiter
del self._puback_waiters[app_message.packet_id]
app_message.puback_packet = waiter.result()
# Discard inflight message
del self.session.inflight_out[app_message.packet_id]
try:
await waiter
app_message.puback_packet = waiter.result()
finally:
self._puback_waiters.pop(app_message.packet_id, None)
# Discard inflight message
self.session.inflight_out.pop(app_message.packet_id, None)
elif app_message.direction == INCOMING:
# Initiate delivery
self.logger.debug("Add message to delivery")
@ -351,9 +352,12 @@ class ProtocolHandler:
raise AMQTTException(message)
waiter = asyncio.Future()
self._pubrec_waiters[app_message.packet_id] = waiter
await waiter
del self._pubrec_waiters[app_message.packet_id]
app_message.pubrec_packet = waiter.result()
try:
await waiter
app_message.pubrec_packet = waiter.result()
finally:
self._pubrec_waiters.pop(app_message.packet_id, None)
self.session.inflight_out.pop(app_message.packet_id, None)
if not app_message.pubcomp_packet:
# Send pubrel
app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id)
@ -361,11 +365,12 @@ class ProtocolHandler:
# Wait for PUBCOMP
waiter = asyncio.Future()
self._pubcomp_waiters[app_message.packet_id] = waiter
await waiter
del self._pubcomp_waiters[app_message.packet_id]
app_message.pubcomp_packet = waiter.result()
# Discard inflight message
del self.session.inflight_out[app_message.packet_id]
try:
await waiter
app_message.pubcomp_packet = waiter.result()
finally:
self._pubcomp_waiters.pop(app_message.packet_id, None)
self.session.inflight_out.pop(app_message.packet_id, None)
elif app_message.direction == INCOMING:
self.session.inflight_in[app_message.packet_id] = app_message
# Send pubrec

Wyświetl plik

@ -159,16 +159,15 @@ class Session:
@property
def next_packet_id(self):
self._packet_id += 1
if self._packet_id > 65535:
self._packet_id = 1
self._packet_id = (self._packet_id % 65535) + 1
limit = self._packet_id
while (
self._packet_id in self.inflight_in or self._packet_id in self.inflight_out
):
self._packet_id += 1
if self._packet_id > 65535:
self._packet_id = (self._packet_id % 65535) + 1
if self._packet_id == limit:
raise AMQTTException(
"More than 65525 messages pending. No free packet ID"
"More than 65535 messages pending. No free packet ID"
)
return self._packet_id

Wyświetl plik

@ -218,3 +218,80 @@ async def test_deliver_timeout():
await client.unsubscribe(["$SYS/broker/uptime"])
await client.disconnect()
await broker.shutdown()
@pytest.mark.asyncio
async def test_cancel_publish_qos1():
"""
Tests that timeouts on published messages will clean up in flight messages
"""
data = b"data"
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
client_pub = MQTTClient()
await client_pub.connect("mqtt://127.0.0.1/")
assert client_pub.session.inflight_out_count == 0
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_1))
assert len(client_pub._handler._puback_waiters) == 0
while len(client_pub._handler._puback_waiters) == 0 or fut.done():
await asyncio.sleep(0)
assert len(client_pub._handler._puback_waiters) == 1
assert client_pub.session.inflight_out_count == 1
fut.cancel()
await asyncio.wait([fut])
assert len(client_pub._handler._puback_waiters) == 0
assert client_pub.session.inflight_out_count == 0
await client_pub.disconnect()
await broker.shutdown()
@pytest.mark.asyncio
async def test_cancel_publish_qos2_pubrec():
"""
Tests that timeouts on published messages will clean up in flight messages
"""
data = b"data"
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
client_pub = MQTTClient()
await client_pub.connect("mqtt://127.0.0.1/")
assert client_pub.session.inflight_out_count == 0
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
assert len(client_pub._handler._pubrec_waiters) == 0
while (
len(client_pub._handler._pubrec_waiters) == 0 or fut.done() or fut.cancelled()
):
await asyncio.sleep(0)
assert len(client_pub._handler._pubrec_waiters) == 1
assert client_pub.session.inflight_out_count == 1
fut.cancel()
await asyncio.sleep(1)
await asyncio.wait([fut])
assert len(client_pub._handler._pubrec_waiters) == 0
assert client_pub.session.inflight_out_count == 0
await client_pub.disconnect()
await broker.shutdown()
@pytest.mark.asyncio
async def test_cancel_publish_qos2_pubcomp():
"""
Tests that timeouts on published messages will clean up in flight messages
"""
data = b"data"
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
await broker.start()
client_pub = MQTTClient()
await client_pub.connect("mqtt://127.0.0.1/")
assert client_pub.session.inflight_out_count == 0
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
assert len(client_pub._handler._pubcomp_waiters) == 0
while len(client_pub._handler._pubcomp_waiters) == 0 or fut.done():
await asyncio.sleep(0)
assert len(client_pub._handler._pubcomp_waiters) == 1
fut.cancel()
await asyncio.wait([fut])
assert len(client_pub._handler._pubcomp_waiters) == 0
assert client_pub.session.inflight_out_count == 0
await client_pub.disconnect()
await broker.shutdown()