diff --git a/nkms/crypto/utils.py b/nkms/crypto/utils.py index db9573bec..0acaee9b9 100644 --- a/nkms/crypto/utils.py +++ b/nkms/crypto/utils.py @@ -47,17 +47,33 @@ class BytestringSplitter(object): def __len__(self): return sum(self.get_message_meta(m)[1] for m in self.message_types) + @staticmethod def get_message_meta(message_type): - if isinstance(message_type, tuple): - message_meta = message_type - else: - try: - message_meta = message_type, message_type._EXPECTED_LENGTH - except AttributeError: - raise TypeError("No way to know the expected length. Either pass it as the second member of a tuple or set _EXPECTED_LENGTH on the class you're passing.") + try: + message_class = message_type[0] + except TypeError: + message_class = message_type - return message_meta + try: + message_length = message_type[1] + except TypeError: + message_length = message_type._EXPECTED_LENGTH + except AttributeError: + raise TypeError("No way to know the expected length. Either pass it as the second member of a tuple or set _EXPECTED_LENGTH on the class you're passing.") + + try: + kwargs = message_type[2] + except (IndexError, TypeError): + kwargs = {} + + return message_class, message_length, kwargs + + def __add__(self, splitter): + return self.__class__(*self.message_types + splitter.message_types) + + def __radd__(self, other): + return other + bytes(self) class RepeatingBytestringSplitter(BytestringSplitter):