From 88cfd0cf476935cd677f6b30649aaea8d7b187a7 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Thu, 23 Apr 2026 15:43:05 +0200 Subject: [PATCH] Improve the construction of inputs to multipart APIs Fix a bug whereby the chunks did not actually have the desired lengths. Make the message content depend only on its length, and not how it is split into chunks. This way, it'll be easier to notice and analyze bugs that cause different outputs for different ways to split the input. Signed-off-by: Gilles Peskine --- .../mldsa_test_generator.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index fed074577..c7ec7d591 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -85,6 +85,35 @@ class Generator: def secret_is_seed(cls) -> bool: return True + @staticmethod + def message_for_length(length: int) -> bytes: + # b'ABCDE...' wrapping around and repeating every 256 bytes + return bytes(b & 0xff for b in range(65, 65 + length)) + + def chunks_for_lengths(self, + lengths: Sequence[int], + arity: Optional[int] = None, + ) -> Tuple[bytes, List[bytes]]: + """Construct a message split in chunks of the given lengths. + + The content of the message only depends on the total length. + + If `arity` is specified and larger than `len(lengths)`, pad the list + of chunks to that number. If `arity` is specified and less than + `len(lengths)`, raise an exception. 
+ """ + total_length = sum(lengths) + message = self.message_for_length(total_length) + chunks = [] + offset = 0 + for n in lengths: + chunks.append(message[offset:offset + n]) + offset += n + if arity is not None: + assert len(lengths) <= arity + chunks += [b''] * (arity - len(lengths)) + return (message, chunks) + def one_mldsa_public_key_from_seed(self, key: Key, descr: str) -> test_case.TestCase: """Construct one test case for driver export_public_key().""" @@ -263,10 +292,7 @@ class DriverGenerator(Generator): The number of message chunks must be at most MULTIPART_ARITY. """ - assert len(lengths) <= self.MULTIPART_ARITY - chunks = ([bytes(i) * n for i, n in enumerate(lengths, 65)] + - [b''] * (self.MULTIPART_ARITY - len(lengths))) - message = b''.join(chunks) + message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY) descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)' type_is_pair = not function.startswith('verif') deterministic = True