diff --git a/Examples/Channel.py b/Examples/Channel.py index f64e427..53b878c 100644 --- a/Examples/Channel.py +++ b/Examples/Channel.py @@ -1,6 +1,6 @@ ########################################################## # This RNS example demonstrates how to set up a link to # -# a destination, and pass structuredmessages over it # +# a destination, and pass structured messages over it # # using a channel. # ########################################################## @@ -46,7 +46,7 @@ class StringMessage(RNS.MessageBase): # message arrives over the channel. # # MSGTYPE must be unique across all message types we - # register with the channel. MSGTYPEs >= 0xff00 are + # register with the channel. MSGTYPEs >= 0xf000 are # reserved for the system. MSGTYPE = 0x0101 @@ -159,17 +159,36 @@ def client_disconnected(link): RNS.log("Client disconnected") def server_message_received(message): + """ + A message handler + @param message: An instance of a subclass of MessageBase + @return: True if message was handled + """ global latest_client_link - # When a message is received over any active link, # the replies will all be directed to the last client # that connected. + + # In a message handler, any deserializable message + # that arrives over the link's channel will be passed + # to all message handlers, unless a preceding handler indicates it + # has handled the message. + # + # if isinstance(message, StringMessage): RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")") reply_message = StringMessage("I received \""+message.data+"\" over the link") latest_client_link.get_channel().send(reply_message) + # Incoming messages are sent to each message + # handler added to the channel, in the order they + # were added. + # If any message handler returns True, the message + # is considered handled and any subsequent + # handlers are skipped. + return True + ########################################################## #### Client Part ######################################### diff --git a/RNS/Channel.py b/RNS/Channel.py index 31aaf94..f6aff67 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -14,6 +14,13 @@ TPacket = TypeVar("TPacket") class ChannelOutletBase(ABC, Generic[TPacket]): + """ + An abstract transport layer interface used by Channel. + + DEPRECATED: This was created for testing; eventually + Channel will use Link or a LinkBase interface + directly. + """ @abstractmethod def send(self, raw: bytes) -> TPacket: raise NotImplemented() @@ -64,6 +71,9 @@ class ChannelOutletBase(ABC, Generic[TPacket]): class CEType(enum.IntEnum): + """ + ChannelException type codes + """ ME_NO_MSG_TYPE = 0 ME_INVALID_MSG_TYPE = 1 ME_NOT_REGISTERED = 2 @@ -73,12 +83,18 @@ class CEType(enum.IntEnum): class ChannelException(Exception): + """ + An exception thrown by Channel, with a type code. + """ def __init__(self, ce_type: CEType, *args): super().__init__(args) self.type = ce_type class MessageState(enum.IntEnum): + """ + Set of possible states for a Message + """ MSGSTATE_NEW = 0 MSGSTATE_SENT = 1 MSGSTATE_DELIVERED = 2 @@ -86,14 +102,29 @@ class MessageState(enum.IntEnum): class MessageBase(abc.ABC): + """ + Base type for any messages sent or received on a Channel. + Subclasses must define the two abstract methods as well as + the MSGTYPE class variable. + """ + # MSGTYPE must be unique within all classes sent over a + # channel. Additionally, MSGTYPE > 0xf000 are reserved. MSGTYPE = None @abstractmethod def pack(self) -> bytes: + """ + Create and return the binary representation of the message + @return: binary representation of message + """ raise NotImplemented() @abstractmethod def unpack(self, raw): + """ + Populate message from binary representation + @param raw: binary representation + """ raise NotImplemented() @@ -101,6 +132,10 @@ MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], boo class Envelope: + """ + Internal wrapper used to transport messages over a channel and + track its state within the channel framework. + """ def unpack(self, message_factories: dict[int, Type]) -> MessageBase: msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) raw = self.raw[6:] @@ -131,6 +166,12 @@ class Envelope: class Channel(contextlib.AbstractContextManager): + """ + Channel provides reliable delivery of messages over + a link. Channel is not meant to be instantiated + directly, but rather obtained from a Link using the + get_channel() function. + """ def __init__(self, outlet: ChannelOutletBase): self._outlet = outlet self._lock = threading.RLock() @@ -146,10 +187,14 @@ class Channel(contextlib.AbstractContextManager): def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, __traceback: TracebackType | None) -> bool | None: - self.shutdown() + self._shutdown() return False def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): + """ + Register a message class for reception over a channel. + @param message_class: Class to register. Must extend MessageBase. + """ with self._lock: if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, @@ -157,7 +202,7 @@ class Channel(contextlib.AbstractContextManager): if message_class.MSGTYPE is None: raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") - if message_class.MSGTYPE >= 0xff00 and not is_system_type: + if message_class.MSGTYPE >= 0xf000 and not is_system_type: raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has system-reserved message type.") try: @@ -169,20 +214,34 @@ class Channel(contextlib.AbstractContextManager): self._message_factories[message_class.MSGTYPE] = message_class def add_message_handler(self, callback: MessageCallbackType): + """ + Add a handler for incoming messages. A handler + has the signature (message: MessageBase) -> bool. + Handlers are processed in the order they are + added. If any handler returns True, processing + of the message stops; handlers after the + returning handler will not be called. + @param callback: Function to call + @return: + """ with self._lock: if callback not in self._message_callbacks: self._message_callbacks.append(callback) def remove_message_handler(self, callback: MessageCallbackType): + """ + Remove a handler + @param callback: handler to remove + """ with self._lock: self._message_callbacks.remove(callback) - def shutdown(self): + def _shutdown(self): with self._lock: self._message_callbacks.clear() - self.clear_rings() + self._clear_rings() - def clear_rings(self): + def _clear_rings(self): with self._lock: for envelope in self._tx_ring: if envelope.packet is not None: @@ -191,14 +250,15 @@ class Channel(contextlib.AbstractContextManager): self._tx_ring.clear() self._rx_ring.clear() - def emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: + def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: with self._lock: i = 0 - for env in ring: - if env.sequence < envelope.sequence: + for existing in ring: + if existing.sequence > envelope.sequence \ + and not existing.sequence // 2 > envelope.sequence: # account for overflow ring.insert(i, envelope) return True - if env.sequence == envelope.sequence: + if existing.sequence == envelope.sequence: RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME) return False i += 1 @@ -206,7 +266,7 @@ class Channel(contextlib.AbstractContextManager): ring.append(envelope) return True - def prune_rx_ring(self): + def _prune_rx_ring(self): with self._lock: # Implementation for fixed window = 1 stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:] @@ -225,13 +285,13 @@ class Channel(contextlib.AbstractContextManager): except Exception as ex: RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) - def receive(self, raw: bytes): + def _receive(self, raw: bytes): try: envelope = Envelope(outlet=self._outlet, raw=raw) with self._lock: message = envelope.unpack(self._message_factories) - is_new = self.emplace_envelope(envelope, self._rx_ring) - self.prune_rx_ring() + is_new = self._emplace_envelope(envelope, self._rx_ring) + self._prune_rx_ring() if not is_new: RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) return @@ -241,6 +301,10 @@ class Channel(contextlib.AbstractContextManager): RNS.log(f"Channel: Error receiving data: {ex}") def is_ready_to_send(self) -> bool: + """ + Check if Channel is ready to send. + @return: True if ready + """ if not self._outlet.is_usable: RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME) return False @@ -273,7 +337,7 @@ class Channel(contextlib.AbstractContextManager): def retry_envelope(envelope: Envelope) -> bool: if envelope.tries >= self._max_tries: RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) - self.shutdown() # start on separate thread? + self._shutdown() # start on separate thread? self._outlet.timed_out() return True envelope.tries += 1 @@ -283,13 +347,18 @@ class Channel(contextlib.AbstractContextManager): self._packet_tx_op(packet, retry_envelope) def send(self, message: MessageBase) -> Envelope: + """ + Send a message. If a message send is attempted and + Channel is not ready, an exception is thrown. + @param message: an instance of a MessageBase subclass to send on the Channel + """ envelope: Envelope | None = None with self._lock: if not self.is_ready_to_send(): raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) self._next_sequence = (self._next_sequence + 1) % 0x10000 - self.emplace_envelope(envelope, self._tx_ring) + self._emplace_envelope(envelope, self._tx_ring) if envelope is None: raise BlockingIOError() @@ -304,10 +373,20 @@ class Channel(contextlib.AbstractContextManager): @property def MDU(self): + """ + Maximum Data Unit: the number of bytes available + for a message to consume in a single send. + @return: number of bytes available + """ return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence) class LinkChannelOutlet(ChannelOutletBase): + """ + An implementation of ChannelOutletBase for RNS.Link. + Allows Channel to send packets over an RNS Link with + Packets. + """ def __init__(self, link: RNS.Link): self.link = link diff --git a/RNS/Link.py b/RNS/Link.py index 5f137d4..0f42388 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -464,7 +464,7 @@ class Link: for resource in self.outgoing_resources: resource.cancel() if self._channel: - self._channel.shutdown() + self._channel._shutdown() self.prv = None self.pub = None @@ -801,7 +801,7 @@ class Link: RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) else: plaintext = self.decrypt(packet.data) - self._channel.receive(plaintext) + self._channel._receive(plaintext) packet.prove() elif packet.packet_type == RNS.Packet.PROOF: diff --git a/tests/channel.py b/tests/channel.py index 245789a..b1097bf 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -153,7 +153,7 @@ class ProtocolHarness(contextlib.AbstractContextManager): self.channel = Channel(self.outlet) def cleanup(self): - self.channel.shutdown() + self.channel._shutdown() def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, __traceback: types.TracebackType) -> bool: @@ -282,7 +282,7 @@ class TestChannel(unittest.TestCase): self.h.channel.add_message_handler(handler2) envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) raw = envelope.pack() - self.h.channel.receive(raw) + self.h.channel._receive(raw) self.assertEqual(1, handler1_called) self.assertEqual(0, handler2_called) @@ -290,7 +290,7 @@ class TestChannel(unittest.TestCase): handler1_return = False envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) raw = envelope.pack() - self.h.channel.receive(raw) + self.h.channel._receive(raw) self.assertEqual(2, handler1_called) self.assertEqual(1, handler2_called) @@ -348,7 +348,7 @@ class TestChannel(unittest.TestCase): self.assertFalse(envelope.tracked) self.assertEqual(0, len(decoded)) - self.h.channel.receive(packet.raw) + self.h.channel._receive(packet.raw) self.assertEqual(1, len(decoded))