diff --git a/Examples/Channel.py b/Examples/Channel.py index 4f1bd2c..f64e427 100644 --- a/Examples/Channel.py +++ b/Examples/Channel.py @@ -46,7 +46,8 @@ class StringMessage(RNS.MessageBase): # message arrives over the channel. # # MSGTYPE must be unique across all message types we - # register with the channel + # register with the channel. MSGTYPEs >= 0xff00 are + # reserved for the system. MSGTYPE = 0x0101 # The constructor of our object must be callable with diff --git a/RNS/Channel.py b/RNS/Channel.py index 0b023be..31aaf94 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -149,7 +149,7 @@ class Channel(contextlib.AbstractContextManager): self.shutdown() return False - def register_message_type(self, message_class: Type[MessageBase]): + def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): with self._lock: if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, @@ -157,6 +157,9 @@ class Channel(contextlib.AbstractContextManager): if message_class.MSGTYPE is None: raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") + if message_class.MSGTYPE >= 0xff00 and not is_system_type: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} has system-reserved message type.") try: message_class() except Exception as ex: diff --git a/tests/channel.py b/tests/channel.py index c9a64b3..03e3bd9 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -137,6 +137,16 @@ class MessageTest(MessageBase): self.id, self.data = umsgpack.unpackb(raw) +class SystemMessage(MessageBase): + MSGTYPE = 0xffff + + def pack(self) -> bytes: + return bytes() + + def unpack(self, raw): + pass + + class ProtocolHarness(contextlib.AbstractContextManager): def __init__(self, rtt: float): self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) @@ -280,6 +290,11 @@ class TestChannel(unittest.TestCase): self.assertEqual(2, handler1_called) self.assertEqual(1, handler2_called) + def test_system_message_check(self): + with self.assertRaises(RNS.Channel.ChannelException): + self.h.channel.register_message_type(SystemMessage) + self.h.channel.register_message_type(SystemMessage, is_system_type=True) + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): decoded: [MessageBase] = [] @@ -287,6 +302,7 @@ class TestChannel(unittest.TestCase): def handle_message(message: MessageBase): decoded.append(message) + self.h.channel.register_message_type(message.__class__) self.h.channel.add_message_handler(handle_message) self.assertEqual(len(self.h.outlet.packets), 0)