From f8d36c84ac486d7e504cc3d9a697dffcbb8f0aa1 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Mon, 20 Apr 2026 10:13:21 +0200 Subject: [PATCH 1/3] Speed up generation by caching signatures We tend to generate the signature of the same message under the same keys multiple times for different APIs (sign/verify, driver/dispatch, ...). Caching results makes the script noticeably faster. Signed-off-by: Gilles Peskine --- util/mbedtls_maintainer/mldsa_test_generator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index 72b20ce35..023e44cd8 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -4,6 +4,7 @@ # Copyright The Mbed TLS Contributors # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later +import functools from typing import Iterator, List, Optional # pip install dilithium-py @@ -41,6 +42,7 @@ class Key: self.seed = seed self.public, self.secret = PURE[kl]._keygen_internal(seed) + @functools.lru_cache(maxsize=9999) def sign_message(self, message: bytes, deterministic: bool) -> bytes: PURE[self.kl].set_drbg_seed(bytes(48)) return PURE[self.kl].sign(self.secret, message, From ffc707d4f6a8ecd0974edf886303904b811378d2 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Mon, 20 Apr 2026 15:17:25 +0200 Subject: [PATCH 2/3] Support generating multipart tests for driver and dispatch Not done by default for smooth transition in the consuming TF-PSA-Crypto branch. It's up to the calling script `generate_mldsa_tests.py` to enable the new test cases. Signed-off-by: Gilles Peskine --- .../mldsa_test_generator.py | 79 ++++++++++++++++++- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index 023e44cd8..fed074577 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -4,8 +4,9 @@ # Copyright The Mbed TLS Contributors # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later +import collections import functools -from typing import Iterator, List, Optional +from typing import Callable, Iterator, List, Optional, Sequence, Tuple # pip install dilithium-py import dilithium_py.ml_dsa #type: ignore @@ -244,14 +245,86 @@ class DriverGenerator(Generator): return arguments @classmethod - def final_arguments(cls) -> List[str]: - return ['PSA_SUCCESS'] + def final_arguments(cls, status: str = 'PSA_SUCCESS') -> List[str]: + return [status] def gen_key_management(self, kl: int) -> Iterator[test_case.TestCase]: """Generate test cases for driver export_public_key().""" for i, key in enumerate(KEYS[kl], 1): yield self.one_mldsa_public_key_from_seed(key, f'key#{i}') + MULTIPART_ARITY = 3 + + def one_multipart(self, function: str, key: Key, + lengths: Sequence[int], + tweak_signature: Optional[Callable[[bytes], Tuple[bytes, str, str]]] = None, + ) -> test_case.TestCase: + """Construct one test case for a multipart operation. + + The number of message chunks must be at most MULTIPART_ARITY. + """ + assert len(lengths) <= self.MULTIPART_ARITY + chunks = ([bytes(i) * n for i, n in enumerate(lengths, 65)] + + [b''] * (self.MULTIPART_ARITY - len(lengths))) + message = b''.join(chunks) + descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)' + type_is_pair = not function.startswith('verif') + deterministic = True + actual_signature = key.sign_message(message, deterministic=deterministic) + signature = actual_signature + status = 'PSA_SUCCESS' + more_descr = '' + if tweak_signature is not None: + signature, status, more_descr = tweak_signature(actual_signature) + if more_descr: + more_descr = ', ' + more_descr + tc = test_case.TestCase() + tc.set_function(self.function(function + '_multipart', key.kl)) + tc.set_dependencies([f'TF_PSA_CRYPTO_PQCP_MLDSA_{key.kl}_ENABLED']) + tc.set_arguments(self.metadata_arguments(key.kl, type_is_pair, True) + [ + test_case.hex_string(key.seed if type_is_pair else key.public), + str(len(lengths)), + ] + [test_case.hex_string(chunk) for chunk in chunks] + [ + test_case.hex_string(signature), + ] + self.final_arguments(status=status)) + tc.set_description(f'MLDSA-{key.kl} {function} multipart {descr}{more_descr}') + return tc + + MANY_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [ + [], [0], [0, 0], + [1], + [3], [1, 2], [2, 1], [1, 1, 1], + [42], [0, 42], [42, 0], [41, 1], + [300], [100, 200], [200, 100], [100, 100, 100], + ] + FEW_MULTIPART_LENGTHS: Sequence[Sequence[int]] = [[], [42], [41, 1]] + + VERIFY_TWEAKS = collections.OrderedDict([ + ('sig=empty', lambda sig: b''), + ('truncated sig', lambda sig: sig[:-1]), + ('sig+garbage', lambda sig: sig + b'\x00'), + ('sig[-1]^=1', lambda sig: sig[:-1] + bytes([sig[-1] ^ 1])), + ('sig[0]^=1', lambda sig: bytes([sig[0] ^ 1]) + sig[1:]), + ]) + + def gen_multipart(self, key: Key) -> Iterator[test_case.TestCase]: + """Generate test cases for multipart sign and verify.""" + for lengths in self.MANY_MULTIPART_LENGTHS: + yield self.one_multipart('sign_deterministic', key, lengths) + yield self.one_multipart('verify', key, lengths) + for descr, func in self.VERIFY_TWEAKS.items(): + for lengths in self.FEW_MULTIPART_LENGTHS: + tweak = (lambda sig, func=func: + (func(sig), 'PSA_ERROR_INVALID_SIGNATURE', descr)) + yield self.one_multipart('verify', key, lengths, + tweak_signature=tweak) + + def gen_all(self, multipart: bool = False) -> Iterator[test_case.TestCase]: + """Generate all the tests for this API.""" + yield from super().gen_all() + if multipart: + for kl in sorted(KEYS.keys()): + yield from self.gen_multipart(KEYS[kl][0]) class DispatchGenerator(DriverGenerator): """Test the driver dispatch layer.""" From 88cfd0cf476935cd677f6b30649aaea8d7b187a7 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Thu, 23 Apr 2026 15:43:05 +0200 Subject: [PATCH 3/3] Improve the construction of inputs to multipart APIs Fix a bug whereby the chunks did not actually have the desired lengths. Make the message content depend only on its length, and not how it is split into chunks. This way, it'll be easier to notice and analyze bugs that cause different outputs for different ways to split the input. Signed-off-by: Gilles Peskine --- .../mldsa_test_generator.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/util/mbedtls_maintainer/mldsa_test_generator.py b/util/mbedtls_maintainer/mldsa_test_generator.py index fed074577..c7ec7d591 100644 --- a/util/mbedtls_maintainer/mldsa_test_generator.py +++ b/util/mbedtls_maintainer/mldsa_test_generator.py @@ -85,6 +85,35 @@ class Generator: def secret_is_seed(cls) -> bool: return True + @staticmethod + def message_for_length(length: int) -> bytes: + # b'ABCDE...' wrapping around and repeating every 256 bytes + return bytes(b & 0xff for b in range(65, 65 + length)) + + def chunks_for_lengths(self, + lengths: Sequence[int], + arity: Optional[int] = None, + ) -> Tuple[bytes, List[bytes]]: + """Construct a message split in chunks of the given lengths. + + The content of the message only depends on the total length. + + If `arity` is specified and less than `len(lengths)`, pad the list + of chunks to that number. If `arity` is specified and larger than + `len(lengths)`, raise an exception. + """ + total_length = sum(lengths) + message = self.message_for_length(total_length) + chunks = [] + offset = 0 + for n in lengths: + chunks.append(message[offset:offset + n]) + offset += n + if arity is not None: + assert len(lengths) <= arity + chunks += [b''] * (arity - len(lengths)) + return (message, chunks) + def one_mldsa_public_key_from_seed(self, key: Key, descr: str) -> test_case.TestCase: """Construct one test case for driver export_public_key().""" @@ -263,10 +292,7 @@ class DriverGenerator(Generator): The number of message chunks must be at most MULTIPART_ARITY. """ - assert len(lengths) <= self.MULTIPART_ARITY - chunks = ([bytes(i) * n for i, n in enumerate(lengths, 65)] + - [b''] * (self.MULTIPART_ARITY - len(lengths))) - message = b''.join(chunks) + message, chunks = self.chunks_for_lengths(lengths, self.MULTIPART_ARITY) descr = '+'.join(map(str, lengths)) if lengths else 'empty (no update)' type_is_pair = not function.startswith('verif') deterministic = True