diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index fed074577..c7ec7d591 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -85,6 +85,35 @@ class Generator: def secret_is_seed(cls) -> bool: return True + @staticmethod + def message_for_length(length: int) -> bytes: + # b'ABCDE...' wrapping around and repeating every 256 bytes + return bytes(b & 0xff for b in range(65, 65 + length)) + + def chunks_for_lengths(self, + lengths: Sequence[int], + arity: Optional[int] = None, + ) -> Tuple[bytes, List[bytes]]: + """Construct a message split in chunks of the given lengths. + + The content of the message only depends on the total length. + + If `arity` is specified and less than `len(lengths)`, pad the list + of chunks to that number. If `arity` is specified and larger than + `len(lengths)`, raise an exception. + """ + total_length = sum(lengths) + message = self.message_for_length(total_length) + chunks = [] + offset = 0 + for n in lengths: + chunks.append(message[offset:offset + n]) + offset += n + if arity is not None: + assert len(lengths) <= arity + chunks += [b''] * (arity - len(lengths)) + return (message, chunks) + def one_mldsa_public_key_from_seed(self, key: Key, descr: str) -> test_case.TestCase: """Construct one test case for driver export_public_key().""" @@ -263,10 +292,7 @@ class DriverGenerator(Generator): The number of message chunks must be at most MULTIPART_ARITY. """ - assert len(lengths) <= self.MULTIPART_ARITY - chunks = ([bytes(i) * n for i, n in enumerate(lengths, 65)] + - [b''] * (self.MULTIPART_ARITY - len(lengths))) - message = b''.join(chunks) + message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY) descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)' type_is_pair = not function.startswith('verif') deterministic = True