kopia lustrzana https://github.com/Yakifo/amqtt
- Updated protocol handlers to more reliably remove active waiters when task cancellation occurs - Fixed checks where expecting a KeyError when it should be checking if not None - Updated next_packet_id property to correctly check if there are any packet_ids available. Avoids infinite loop if all packet ids are used.main
rodzic
1a2812c5fc
commit
d0eb64dc19
|
@ -106,17 +106,18 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
# Wait for SUBACK is received
|
||||
waiter = futures.Future()
|
||||
self._subscriptions_waiter[subscribe.variable_header.packet_id] = waiter
|
||||
return_codes = await waiter
|
||||
|
||||
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
|
||||
try:
|
||||
return_codes = await waiter
|
||||
finally:
|
||||
del self._subscriptions_waiter[subscribe.variable_header.packet_id]
|
||||
return return_codes
|
||||
|
||||
async def handle_suback(self, suback: SubackPacket):
|
||||
packet_id = suback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._subscriptions_waiter.get(packet_id)
|
||||
waiter = self._subscriptions_waiter.get(packet_id)
|
||||
if waiter is not None:
|
||||
waiter.set_result(suback.payload.return_codes)
|
||||
except KeyError:
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Received SUBACK for unknown pending subscription with Id: %s"
|
||||
% packet_id
|
||||
|
@ -132,15 +133,17 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
await self._send_packet(unsubscribe)
|
||||
waiter = futures.Future()
|
||||
self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id] = waiter
|
||||
await waiter
|
||||
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
|
||||
try:
|
||||
await waiter
|
||||
finally:
|
||||
del self._unsubscriptions_waiter[unsubscribe.variable_header.packet_id]
|
||||
|
||||
async def handle_unsuback(self, unsuback: UnsubackPacket):
|
||||
packet_id = unsuback.variable_header.packet_id
|
||||
try:
|
||||
waiter = self._unsubscriptions_waiter.get(packet_id)
|
||||
waiter = self._unsubscriptions_waiter.get(packet_id)
|
||||
if waiter is not None:
|
||||
waiter.set_result(None)
|
||||
except KeyError:
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Received UNSUBACK for unknown pending subscription with Id: %s"
|
||||
% packet_id
|
||||
|
@ -152,10 +155,12 @@ class ClientProtocolHandler(ProtocolHandler):
|
|||
|
||||
async def mqtt_ping(self):
|
||||
ping_packet = PingReqPacket()
|
||||
await self._send_packet(ping_packet)
|
||||
resp = await self._pingresp_queue.get()
|
||||
if self._ping_task:
|
||||
self._ping_task = None
|
||||
try:
|
||||
await self._send_packet(ping_packet)
|
||||
resp = await self._pingresp_queue.get()
|
||||
finally:
|
||||
if self._ping_task:
|
||||
self._ping_task = None
|
||||
return resp
|
||||
|
||||
async def handle_pingresp(self, pingresp: PingRespPacket):
|
||||
|
|
|
@ -293,12 +293,13 @@ class ProtocolHandler:
|
|||
# Wait for puback
|
||||
waiter = asyncio.Future()
|
||||
self._puback_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
del self._puback_waiters[app_message.packet_id]
|
||||
app_message.puback_packet = waiter.result()
|
||||
|
||||
# Discard inflight message
|
||||
del self.session.inflight_out[app_message.packet_id]
|
||||
try:
|
||||
await waiter
|
||||
app_message.puback_packet = waiter.result()
|
||||
finally:
|
||||
self._puback_waiters.pop(app_message.packet_id, None)
|
||||
# Discard inflight message
|
||||
self.session.inflight_out.pop(app_message.packet_id, None)
|
||||
elif app_message.direction == INCOMING:
|
||||
# Initiate delivery
|
||||
self.logger.debug("Add message to delivery")
|
||||
|
@ -351,9 +352,12 @@ class ProtocolHandler:
|
|||
raise AMQTTException(message)
|
||||
waiter = asyncio.Future()
|
||||
self._pubrec_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
del self._pubrec_waiters[app_message.packet_id]
|
||||
app_message.pubrec_packet = waiter.result()
|
||||
try:
|
||||
await waiter
|
||||
app_message.pubrec_packet = waiter.result()
|
||||
finally:
|
||||
self._pubrec_waiters.pop(app_message.packet_id, None)
|
||||
self.session.inflight_out.pop(app_message.packet_id, None)
|
||||
if not app_message.pubcomp_packet:
|
||||
# Send pubrel
|
||||
app_message.pubrel_packet = PubrelPacket.build(app_message.packet_id)
|
||||
|
@ -361,11 +365,12 @@ class ProtocolHandler:
|
|||
# Wait for PUBCOMP
|
||||
waiter = asyncio.Future()
|
||||
self._pubcomp_waiters[app_message.packet_id] = waiter
|
||||
await waiter
|
||||
del self._pubcomp_waiters[app_message.packet_id]
|
||||
app_message.pubcomp_packet = waiter.result()
|
||||
# Discard inflight message
|
||||
del self.session.inflight_out[app_message.packet_id]
|
||||
try:
|
||||
await waiter
|
||||
app_message.pubcomp_packet = waiter.result()
|
||||
finally:
|
||||
self._pubcomp_waiters.pop(app_message.packet_id, None)
|
||||
self.session.inflight_out.pop(app_message.packet_id, None)
|
||||
elif app_message.direction == INCOMING:
|
||||
self.session.inflight_in[app_message.packet_id] = app_message
|
||||
# Send pubrec
|
||||
|
|
|
@ -159,16 +159,15 @@ class Session:
|
|||
|
||||
@property
|
||||
def next_packet_id(self):
|
||||
self._packet_id += 1
|
||||
if self._packet_id > 65535:
|
||||
self._packet_id = 1
|
||||
self._packet_id = (self._packet_id % 65535) + 1
|
||||
limit = self._packet_id
|
||||
while (
|
||||
self._packet_id in self.inflight_in or self._packet_id in self.inflight_out
|
||||
):
|
||||
self._packet_id += 1
|
||||
if self._packet_id > 65535:
|
||||
self._packet_id = (self._packet_id % 65535) + 1
|
||||
if self._packet_id == limit:
|
||||
raise AMQTTException(
|
||||
"More than 65525 messages pending. No free packet ID"
|
||||
"More than 65535 messages pending. No free packet ID"
|
||||
)
|
||||
|
||||
return self._packet_id
|
||||
|
|
|
@ -218,3 +218,80 @@ async def test_deliver_timeout():
|
|||
await client.unsubscribe(["$SYS/broker/uptime"])
|
||||
await client.disconnect()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_publish_qos1():
|
||||
"""
|
||||
Tests that timeouts on published messages will clean up in flight messages
|
||||
"""
|
||||
data = b"data"
|
||||
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
|
||||
await broker.start()
|
||||
client_pub = MQTTClient()
|
||||
await client_pub.connect("mqtt://127.0.0.1/")
|
||||
assert client_pub.session.inflight_out_count == 0
|
||||
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_1))
|
||||
assert len(client_pub._handler._puback_waiters) == 0
|
||||
while len(client_pub._handler._puback_waiters) == 0 or fut.done():
|
||||
await asyncio.sleep(0)
|
||||
assert len(client_pub._handler._puback_waiters) == 1
|
||||
assert client_pub.session.inflight_out_count == 1
|
||||
fut.cancel()
|
||||
await asyncio.wait([fut])
|
||||
assert len(client_pub._handler._puback_waiters) == 0
|
||||
assert client_pub.session.inflight_out_count == 0
|
||||
await client_pub.disconnect()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_publish_qos2_pubrec():
|
||||
"""
|
||||
Tests that timeouts on published messages will clean up in flight messages
|
||||
"""
|
||||
data = b"data"
|
||||
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
|
||||
await broker.start()
|
||||
client_pub = MQTTClient()
|
||||
await client_pub.connect("mqtt://127.0.0.1/")
|
||||
assert client_pub.session.inflight_out_count == 0
|
||||
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
|
||||
assert len(client_pub._handler._pubrec_waiters) == 0
|
||||
while (
|
||||
len(client_pub._handler._pubrec_waiters) == 0 or fut.done() or fut.cancelled()
|
||||
):
|
||||
await asyncio.sleep(0)
|
||||
assert len(client_pub._handler._pubrec_waiters) == 1
|
||||
assert client_pub.session.inflight_out_count == 1
|
||||
fut.cancel()
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.wait([fut])
|
||||
assert len(client_pub._handler._pubrec_waiters) == 0
|
||||
assert client_pub.session.inflight_out_count == 0
|
||||
await client_pub.disconnect()
|
||||
await broker.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_publish_qos2_pubcomp():
|
||||
"""
|
||||
Tests that timeouts on published messages will clean up in flight messages
|
||||
"""
|
||||
data = b"data"
|
||||
broker = Broker(broker_config, plugin_namespace="amqtt.test.plugins")
|
||||
await broker.start()
|
||||
client_pub = MQTTClient()
|
||||
await client_pub.connect("mqtt://127.0.0.1/")
|
||||
assert client_pub.session.inflight_out_count == 0
|
||||
fut = asyncio.create_task(client_pub.publish("test_topic", data, QOS_2))
|
||||
assert len(client_pub._handler._pubcomp_waiters) == 0
|
||||
while len(client_pub._handler._pubcomp_waiters) == 0 or fut.done():
|
||||
await asyncio.sleep(0)
|
||||
assert len(client_pub._handler._pubcomp_waiters) == 1
|
||||
fut.cancel()
|
||||
await asyncio.wait([fut])
|
||||
assert len(client_pub._handler._pubcomp_waiters) == 0
|
||||
assert client_pub.session.inflight_out_count == 0
|
||||
await client_pub.disconnect()
|
||||
await broker.shutdown()
|
||||
|
|
Ładowanie…
Reference in New Issue